From c7c5151acf1041013a9238db54759650d3bef11c Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Thu, 30 Jun 2016 23:13:53 +0200 Subject: [PATCH] Tests comments and move http server handlers from cas_server.utils to cas_server.tests.utils --- cas_server/tests/test_models.py | 15 +++++-- cas_server/tests/test_view.py | 65 +++++++++++++++++++++++------- cas_server/tests/utils.py | 71 ++++++++++++++++++++++++++++++++- cas_server/utils.py | 68 +------------------------------ 4 files changed, 133 insertions(+), 86 deletions(-) diff --git a/cas_server/tests/test_models.py b/cas_server/tests/test_models.py index 5b001a7..372fd5a 100644 --- a/cas_server/tests/test_models.py +++ b/cas_server/tests/test_models.py @@ -9,8 +9,7 @@ from datetime import timedelta from importlib import import_module from cas_server import models -from cas_server import utils -from cas_server.tests.utils import get_auth_client +from cas_server.tests.utils import get_auth_client, HttpParamsHandler from cas_server.tests.mixin import UserModels, BaseServicePattern SessionStore = import_module(settings.SESSION_ENGINE).SessionStore @@ -125,22 +124,32 @@ class TicketTestCase(TestCase, UserModels, BaseServicePattern): def test_clean_old_service_ticket(self): """test tickets clean_old_entries""" + # ge an authenticated client client = get_auth_client() + # get the user associated to the client user = self.get_user(client) + # generate a ticket for that client, waiting for validation self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern) + # generate another ticket for those validation time has expired self.get_ticket( user, models.ServiceTicket, self.service, self.service_pattern, validity_expired=True ) - (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] + (httpd, host, port) = HttpParamsHandler.run()[0:3] service = "http://%s:%s" % (host, port) + # generate a ticket with SLO having timeout reach self.get_ticket( user, models.ServiceTicket, service, self.service_pattern, timeout_expired=True, validate=True, single_log_out=True ) + # there should be 3 tickets in the db self.assertEqual(len(models.ServiceTicket.objects.all()), 3) + # we call the clean_old_entries method that should delete validated non SLO ticket and + # expired non validated ticket and send SLO for SLO expired ticket before deleting then models.ServiceTicket.clean_old_entries() params = httpd.PARAMS + # we successfully got a SLO request self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest']) + # only 1 ticket remain in the db self.assertEqual(len(models.ServiceTicket.objects.all()), 1) diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py index f1bdb85..8c8a900 100644 --- a/cas_server/tests/test_view.py +++ b/cas_server/tests/test_view.py @@ -21,7 +21,9 @@ from cas_server.tests.utils import ( get_user_ticket_request, get_pgt, get_proxy_ticket, - get_validated_ticket + get_validated_ticket, + HttpParamsHandler, + Http404Handler ) from cas_server.tests.mixin import BaseServicePattern, XmlContent @@ -697,7 +699,7 @@ class LogoutTestCase(TestCase): # test normal SLO # setup a simple one request http server - (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] + (httpd, host, port) = HttpParamsHandler.run()[0:3] # build a service url depending on which port the http server has binded service = "http://%s:%s" % (host, port) # get a ticket requested by client and being validated by the service @@ -709,7 +711,7 @@ class LogoutTestCase(TestCase): # text SLO with a single_log_out_callback # setup a simple one request http server - (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] + (httpd, host, port) = HttpParamsHandler.run()[0:3] # set the default test service pattern to use the http server port for SLO requests. # in fact, this single_log_out_callback parametter is usefull to implement SLO # for non http service like imap or ftp @@ -1273,7 +1275,7 @@ class ValidateServiceTestCase(TestCase, XmlContent): def test_validate_service_view_ok_pgturl(self): """test the retrieval of a ProxyGrantingTicket""" # start a simple on request http server - (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] + (httpd, host, port) = HttpParamsHandler.run()[0:3] # construct the service from it service = "http://%s:%s" % (host, port) @@ -1304,7 +1306,7 @@ class ValidateServiceTestCase(TestCase, XmlContent): def test_validate_service_pgturl_sslerror(self): """test the retrieval of a ProxyGrantingTicket with a SSL error on the pgtUrl""" - (host, port) = utils.HttpParamsHandler.run()[1:3] + (host, port) = HttpParamsHandler.run()[1:3] # is fact the service listen on http and not https raisin a SSL Protocol Error # but other SSL/TLS error should behave the same service = "https://%s:%s" % (host, port) @@ -1329,7 +1331,7 @@ class ValidateServiceTestCase(TestCase, XmlContent): test the retrieval on a ProxyGrantingTicket then to pgtUrl return a http error. PGT creation should be aborted but the ticket still be valid """ - (host, port) = utils.Http404Handler.run()[1:3] + (host, port) = Http404Handler.run()[1:3] service = "http://%s:%s" % (host, port) ticket = get_user_ticket_request(service)[1] @@ -1424,8 +1426,10 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): """tests for the proxy view""" def setUp(self): """preparing test context""" + # we prepare a bunch a service url and service patterns for tests self.setup_service_patterns(proxy=True) + # set the default service pattern to localhost to be able to retrieve PGT self.service = 'http://127.0.0.1' self.service_pattern = models.ServicePattern.objects.create( name="localhost", @@ -1433,6 +1437,7 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): proxy=True, proxy_callback=True ) + # transmit all attributes models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) def test_validate_proxy_ok(self): @@ -1440,13 +1445,20 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): Get a PGT, get a proxy ticket, validate it. Validation should succeed and show the proxy service URL. """ + # we directrly get a ProxyGrantingTicket params = get_pgt() - # get a proxy ticket + # We try get a proxy ticket with our PGT client1 = Client() - response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service}) + # for what we send a GET request to /proxy with ge PGT and the target service for which + # we want a ProxyTicket to. + response = client1.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': "https://www.example.com"} + ) self.assertEqual(response.status_code, 200) + # we should sucessfully reteive a PT root = etree.fromstring(response.content) sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertTrue(sucess) @@ -1458,16 +1470,21 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): self.assertEqual(len(proxy_ticket), 1) proxy_ticket = proxy_ticket[0].text - # validate the proxy ticket + # validate the proxy ticket with the service for which is was emitted client2 = Client() - response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service}) + response = client2.get( + '/proxyValidate', + {'ticket': proxy_ticket, 'service': "https://www.example.com"} + ) + # validation should succeed and return settings.CAS_TEST_USER as username + # and settings.CAS_TEST_ATTRIBUTES as attributes root = self.assert_success( response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES ) - # check that the proxy is send to the end service + # in the PT validation response, it should have the service url of the PGY proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertEqual(len(proxies), 1) proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"}) @@ -1476,6 +1493,7 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): def test_validate_proxy_bad_pgt(self): """Try to get a ProxyTicket with a bad PGT. The PT generation should fail""" + # we directrly get a ProxyGrantingTicket params = get_pgt() client = Client() response = client.get( @@ -1496,8 +1514,10 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): Try to get a ProxyTicket for a denied service and a service that do not allow PT. The PT generation should fail. """ + # we directrly get a ProxyGrantingTicket params = get_pgt() + # try to get a PT for a denied service client1 = Client() response = client1.get( '/proxy', @@ -1509,7 +1529,7 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): "https://www.example.org" ) - # service do not allow proxy ticket + # try to get a PT for a service that do not allow PT self.service_pattern.proxy = False self.service_pattern.save() @@ -1531,16 +1551,20 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): def test_proxy_unauthorized_user(self): """ Try to get a PT for services that do not allow the current user: - * first with a service that restrict allower username + * first with a service that restrict allowed username * second with a service requiring somes conditions on the user attributes * third with a service using a particular user attribute as username All this tests should fail """ + # we directrly get a ProxyGrantingTicket params = get_pgt() for service in [ + # do ot allow the test username self.service_restrict_user_fail, + # require the 'nom' attribute to be 'toto' self.service_filter_fail, + # want to use the non-exitant 'uid' attribute as username self.service_field_needed_fail ]: client = Client() @@ -1548,6 +1572,7 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): '/proxy', {'pgt': params['pgtId'], 'targetService': service} ) + # PT generation should fail self.assert_error( response, "UNAUTHORIZED_USER", @@ -1575,8 +1600,10 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): """tests for the proxy view""" def setUp(self): """preparing test context""" + # we prepare a bunch a service url and service patterns for tests self.setup_service_patterns(proxy=True) + # special service pattern for retrieving a PGT self.service_pgt = 'http://127.0.0.1' self.service_pattern_pgt = models.ServicePattern.objects.create( name="localhost", @@ -1589,6 +1616,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): service_pattern=self.service_pattern_pgt ) + # template for the XML POST need to be send to validate a ticket using SAML 1.1 xml_template = """ @@ -1607,6 +1635,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): def assert_success(self, response, username, original_attributes): """assert ticket validation success""" self.assertEqual(response.status_code, 200) + # on validation success, the response should have a StatusCode set to Success root = etree.fromstring(response.content) success = root.xpath( "//samlp:StatusCode", @@ -1615,6 +1644,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): self.assertEqual(len(success), 1) self.assertTrue(success[0].attrib['Value'].endswith(":Success")) + # the user username should be return whithin tags user = root.xpath( "//samla:NameIdentifier", namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"} @@ -1622,6 +1652,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): self.assertTrue(user) self.assertEqual(user[0].text, username) + # the returned attributes should match original_attributes attributes = root.xpath( "//samla:AttributeStatement/samla:Attribute", namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"} @@ -1641,6 +1672,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): def assert_error(self, response, code, msg=None): """assert ticket validation error""" self.assertEqual(response.status_code, 200) + # on error the status code value should be the one provider in `code` root = etree.fromstring(response.content) error = root.xpath( "//samlp:StatusCode", @@ -1648,6 +1680,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): ) self.assertEqual(len(error), 1) self.assertTrue(error[0].attrib['Value'].endswith(":%s" % code)) + # it may have an error message if msg is not None: self.assertEqual(error[0].text, msg) @@ -1656,12 +1689,15 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): test with a valid (ticket, service), with a ST and a PT, the username and all attributes are transmited""" tickets = [ + # return a ServiceTicket (standard ticket) waiting for validation get_user_ticket_request(self.service)[1], + # return a PT waiting for validation get_proxy_ticket(self.service) ] for ticket in tickets: client = Client() + # we send the POST validation requests response = client.post( '/samlValidate?TARGET=%s' % self.service, self.xml_template % { @@ -1671,6 +1707,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): }, content_type="text/xml; encoding='utf-8'" ) + # and it should succeed self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES) def test_saml_ok_user_field(self): @@ -1734,7 +1771,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): ) def test_saml_bad_target(self): - """test with a valid(ticket, service), but using a bad target""" + """test with a valid ticket, but using a bad target, validation should fail""" bad_target = "https://www.example.org" ticket = get_user_ticket_request(self.service)[1] diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py index ef06c93..a0f4aa8 100644 --- a/cas_server/tests/utils.py +++ b/cas_server/tests/utils.py @@ -3,10 +3,13 @@ from cas_server.default_settings import settings from django.test import Client +import cgi +from threading import Thread from lxml import etree +from six.moves import BaseHTTPServer +from six.moves.urllib.parse import urlparse, parse_qsl from cas_server import models -from cas_server import utils def copy_form(form): @@ -70,7 +73,7 @@ def get_validated_ticket(service): def get_pgt(): """return a dict contening a service, user and PGT ticket for this service""" - (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] + (httpd, host, port) = HttpParamsHandler.run()[0:3] service = "http://%s:%s" % (host, port) (user, ticket) = get_user_ticket_request(service)[:2] @@ -100,3 +103,67 @@ def get_proxy_ticket(service): proxy_ticket = proxy_ticket[0].text ticket = models.ProxyTicket.objects.get(value=proxy_ticket) return ticket + + +class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler): + """ + A simple http server that return 200 on GET or POST + and store GET or POST parameters. Used in unit tests + """ + + def do_GET(self): + """Called on a GET request on the BaseHTTPServer""" + self.send_response(200) + self.send_header(b"Content-type", "text/plain") + self.end_headers() + self.wfile.write(b"ok") + url = urlparse(self.path) + params = dict(parse_qsl(url.query)) + self.server.PARAMS = params + + def do_POST(self): + """Called on a POST request on the BaseHTTPServer""" + ctype, pdict = cgi.parse_header(self.headers.get('content-type')) + if ctype == 'multipart/form-data': + postvars = cgi.parse_multipart(self.rfile, pdict) + elif ctype == 'application/x-www-form-urlencoded': + length = int(self.headers.get('content-length')) + postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1) + else: + postvars = {} + self.server.PARAMS = postvars + + def log_message(self, *args): + """silent any log message""" + return + + @classmethod + def run(cls): + """Run a BaseHTTPServer using this class as handler""" + server_class = BaseHTTPServer.HTTPServer + httpd = server_class(("127.0.0.1", 0), cls) + (host, port) = httpd.socket.getsockname() + + def lauch(): + """routine to lauch in a background thread""" + httpd.handle_request() + httpd.server_close() + + httpd_thread = Thread(target=lauch) + httpd_thread.daemon = True + httpd_thread.start() + return (httpd, host, port) + + +class Http404Handler(HttpParamsHandler): + """A simple http server that always return 404 not found. Used in unit tests""" + def do_GET(self): + """Called on a GET request on the BaseHTTPServer""" + self.send_response(404) + self.send_header(b"Content-type", "text/plain") + self.end_headers() + self.wfile.write(b"error 404 not found") + + def do_POST(self): + """Called on a POST request on the BaseHTTPServer""" + return self.do_GET() diff --git a/cas_server/utils.py b/cas_server/utils.py index f85c25e..3be2bad 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -23,10 +23,8 @@ import hashlib import crypt import base64 import six -import cgi -from threading import Thread + from importlib import import_module -from six.moves import BaseHTTPServer from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode @@ -151,70 +149,6 @@ def gen_saml_id(): return _gen_ticket('_') -class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler): - """ - A simple http server that return 200 on GET or POST - and store GET or POST parameters. Used in unit tests - """ - - def do_GET(self): - """Called on a GET request on the BaseHTTPServer""" - self.send_response(200) - self.send_header(b"Content-type", "text/plain") - self.end_headers() - self.wfile.write(b"ok") - url = urlparse(self.path) - params = dict(parse_qsl(url.query)) - self.server.PARAMS = params - - def do_POST(self): - """Called on a POST request on the BaseHTTPServer""" - ctype, pdict = cgi.parse_header(self.headers.get('content-type')) - if ctype == 'multipart/form-data': - postvars = cgi.parse_multipart(self.rfile, pdict) - elif ctype == 'application/x-www-form-urlencoded': - length = int(self.headers.get('content-length')) - postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1) - else: - postvars = {} - self.server.PARAMS = postvars - - def log_message(self, *args): - """silent any log message""" - return - - @classmethod - def run(cls): - """Run a BaseHTTPServer using this class as handler""" - server_class = BaseHTTPServer.HTTPServer - httpd = server_class(("127.0.0.1", 0), cls) - (host, port) = httpd.socket.getsockname() - - def lauch(): - """routine to lauch in a background thread""" - httpd.handle_request() - httpd.server_close() - - httpd_thread = Thread(target=lauch) - httpd_thread.daemon = True - httpd_thread.start() - return (httpd, host, port) - - -class Http404Handler(HttpParamsHandler): - """A simple http server that always return 404 not found. Used in unit tests""" - def do_GET(self): - """Called on a GET request on the BaseHTTPServer""" - self.send_response(404) - self.send_header(b"Content-type", "text/plain") - self.end_headers() - self.wfile.write(b"error 404 not found") - - def do_POST(self): - """Called on a POST request on the BaseHTTPServer""" - return self.do_GET() - - class LdapHashUserPassword(object): """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""