# -*- coding: utf-8 -*- # This program is distributed in the hope that it will be useful, but WITHOUT # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS # FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for # more details. # # You should have received a copy of the GNU General Public License version 3 # along with this program; if not, write to the Free Software Foundation, Inc., 51 # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # # (c) 2016 Valentin Samir """Some utils functions for tests""" from cas_server.default_settings import settings import django from django.test import Client from django.template import loader from django.utils import timezone if django.VERSION < (1, 8): from django.template import Context else: Context = lambda x:x import cgi import six from threading import Thread from lxml import etree from six.moves import BaseHTTPServer from six.moves.urllib.parse import urlparse, parse_qsl, parse_qs from datetime import timedelta from cas_server import models from cas_server import utils def return_unicode(string, charset): """make `string` a unicode if `string` is a unicode or bytes encoded with `charset`""" if not isinstance(string, six.text_type): return string.decode(charset) else: return string def return_bytes(string, charset): """ make `string` a bytes encoded with `charset` if `string` is a unicode or bytes encoded with `charset` """ if isinstance(string, six.text_type): return string.encode(charset) else: return string def copy_form(form): """Copy form value into a dict""" params = {} for field in form: if field.value(): params[field.name] = field.value() else: params[field.name] = "" return params def get_login_page_params(client=None): """Return a client and the POST params for the client to login""" if client is None: client = Client() response = client.get('/login') params = copy_form(response.context["form"]) return client, params def get_auth_client(**update): """return a authenticated client""" client, params = get_login_page_params() params["username"] = settings.CAS_TEST_USER params["password"] = settings.CAS_TEST_PASSWORD params.update(update) response = client.post('/login', params) assert client.session.get("authenticated") if params.get("service"): return (client, response) else: return client def get_user_ticket_request(service): """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)""" client = get_auth_client() response = client.get("/login", {"service": service}) ticket_value = response['Location'].split('ticket=')[-1] user = models.User.objects.get( username=settings.CAS_TEST_USER, session_key=client.session.session_key ) ticket = models.ServiceTicket.objects.get(value=ticket_value) return (user, ticket, client) def get_validated_ticket(service): """Return a tick that has being already validated. Used to test SLO""" (ticket, auth_client) = get_user_ticket_request(service)[1:3] client = Client() response = client.get('/validate', {'ticket': ticket.value, 'service': service}) assert (response.status_code == 200) assert (response.content == b'yes\ntest\n') ticket = models.ServiceTicket.objects.get(value=ticket.value) return (auth_client, ticket) def get_pgt(): """return a dict contening a service, user and PGT ticket for this service""" (httpd, host, port) = HttpParamsHandler.run()[0:3] service = "http://%s:%s" % (host, port) (user, ticket) = get_user_ticket_request(service)[:2] client = Client() client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) params = httpd.PARAMS params["service"] = service params["user"] = user return params def get_proxy_ticket(service): """Return a ProxyTicket waiting for validation""" params = get_pgt() # get a proxy ticket client = Client() response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service}) root = etree.fromstring(response.content) proxy_ticket = root.xpath( "//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"} ) proxy_ticket = proxy_ticket[0].text ticket = models.ProxyTicket.objects.get(value=proxy_ticket) return ticket class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler): """ A simple http server that return 200 on GET or POST and store GET or POST parameters. Used in unit tests """ def do_GET(self): """Called on a GET request on the BaseHTTPServer""" self.send_response(200) self.send_header(b"Content-type", "text/plain") self.end_headers() self.wfile.write(b"ok") url = urlparse(self.path) params = dict(parse_qsl(url.query)) self.server.PARAMS = params def do_POST(self): """Called on a POST request on the BaseHTTPServer""" ctype, pdict = cgi.parse_header(self.headers.get('content-type')) if ctype == 'multipart/form-data': postvars = cgi.parse_multipart(self.rfile, pdict) elif ctype == 'application/x-www-form-urlencoded': length = int(self.headers.get('content-length')) postvars = parse_qs(self.rfile.read(length), keep_blank_values=1) else: postvars = {} self.server.PARAMS = postvars def log_message(self, *args): """silent any log message""" return @classmethod def run(cls, port=0): """Run a BaseHTTPServer using this class as handler""" server_class = BaseHTTPServer.HTTPServer httpd = server_class(("127.0.0.1", port), cls) (host, port) = httpd.socket.getsockname() def lauch(): """routine to lauch in a background thread""" httpd.handle_request() httpd.server_close() httpd_thread = Thread(target=lauch) httpd_thread.daemon = True httpd_thread.start() return (httpd, host, port) class Http404Handler(HttpParamsHandler): """A simple http server that always return 404 not found. Used in unit tests""" def do_GET(self): """Called on a GET request on the BaseHTTPServer""" self.send_response(404) self.send_header(b"Content-type", "text/plain") self.end_headers() self.wfile.write(b"error 404 not found") def do_POST(self): """Called on a POST request on the BaseHTTPServer""" return self.do_GET() class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler): """A dummy CAS that validate for only one (service, ticket) used in federated mode tests""" def test_params(self): """check that internal and provided (service, ticket) matches""" if ( self.server.ticket is not None and self.params.get("service").encode("ascii") == self.server.service and self.params.get("ticket").encode("ascii") == self.server.ticket ): self.server.ticket = None return True else: return False def send_headers(self, code, content_type): """send http headers""" self.send_response(code) self.send_header("Content-type", content_type) self.end_headers() def do_GET(self): """Called on a GET request on the BaseHTTPServer""" url = urlparse(self.path) self.params = dict(parse_qsl(url.query)) if url.path == "/validate": self.send_headers(200, "text/plain; charset=utf-8") if self.test_params(): self.wfile.write(b"yes\n" + self.server.username + b"\n") self.server.ticket = None else: self.wfile.write(b"no\n") elif url.path in { '/serviceValidate', '/serviceValidate', '/p3/serviceValidate', '/p3/proxyValidate' }: self.send_headers(200, "text/xml; charset=utf-8") if self.test_params(): template = loader.get_template('cas_server/serviceValidate.xml') context = Context({ 'username': self.server.username, 'attributes': self.server.attributes }) self.wfile.write(return_bytes(template.render(context), "utf8")) else: template = loader.get_template('cas_server/serviceValidateError.xml') context = Context({ 'code': 'BAD_SERVICE_TICKET', 'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket) }) self.wfile.write(return_bytes(template.render(context), "utf8")) else: self.return_404() def do_POST(self): """Called on a POST request on the BaseHTTPServer""" url = urlparse(self.path) self.params = dict(parse_qsl(url.query)) if url.path == "/samlValidate": self.send_headers(200, "text/xml; charset=utf-8") length = int(self.headers.get('content-length')) root = etree.fromstring(self.rfile.read(length)) auth_req = root.getchildren()[1].getchildren()[0] ticket = auth_req.getchildren()[0].text.encode("ascii") if ( self.server.ticket is not None and self.params.get("TARGET").encode("ascii") == self.server.service and ticket == self.server.ticket ): self.server.ticket = None template = loader.get_template('cas_server/samlValidate.xml') context = Context({ 'IssueInstant': timezone.now().isoformat(), 'expireInstant': (timezone.now() + timedelta(seconds=60)).isoformat(), 'Recipient': self.server.service, 'ResponseID': utils.gen_saml_id(), 'username': self.server.username, 'attributes': self.server.attributes, }) self.wfile.write(return_bytes(template.render(context), "utf8")) else: template = loader.get_template('cas_server/samlValidateError.xml') context = Context({ 'IssueInstant': timezone.now().isoformat(), 'ResponseID': utils.gen_saml_id(), 'code': 'BAD_SERVICE_TICKET', 'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket) }) self.wfile.write(return_bytes(template.render(context), "utf8")) else: self.return_404() def return_404(self): """return a 404 error""" self.send_headers(404, "text/plain; charset=utf-8") self.wfile.write("not found") def log_message(self, *args): """silent any log message""" return @classmethod def run(cls, service, ticket, username, attributes, port=0): """Run a BaseHTTPServer using this class as handler""" server_class = BaseHTTPServer.HTTPServer httpd = server_class(("127.0.0.1", port), cls) httpd.service = service httpd.ticket = ticket httpd.username = username httpd.attributes = attributes (host, port) = httpd.socket.getsockname() def lauch(): """routine to lauch in a background thread""" httpd.handle_request() httpd.server_close() httpd_thread = Thread(target=lauch) httpd_thread.daemon = True httpd_thread.start() return (httpd, host, port) def logout_request(ticket): """build a SLO request XML, ready to be send""" return u""" %(ticket)s """ % \ { 'id': utils.gen_saml_id(), 'datetime': timezone.now().isoformat(), 'ticket': ticket }