Cleaner BaseHTTPRequestHandler

This commit is contained in:
Valentin Samir 2016-06-29 20:51:30 +02:00
parent e5efdadde0
commit d4b9d66051
3 changed files with 32 additions and 13 deletions

View File

@ -863,7 +863,7 @@ class ValidateServiceTestCase(TestCase, XmlContent):
def test_validate_service_view_ok_pgturl(self): def test_validate_service_view_ok_pgturl(self):
"""test the retrieval of a ProxyGrantingTicket""" """test the retrieval of a ProxyGrantingTicket"""
(host, port) = utils.PGTUrlHandler.run()[1:3] (httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port) service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1] ticket = get_user_ticket_request(service)[1]
@ -873,7 +873,7 @@ class ValidateServiceTestCase(TestCase, XmlContent):
'/serviceValidate', '/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service} {'ticket': ticket.value, 'service': service, 'pgtUrl': service}
) )
pgt_params = utils.PGTUrlHandler.PARAMS.copy() pgt_params = httpd.PARAMS
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content) root = etree.fromstring(response.content)
@ -887,7 +887,7 @@ class ValidateServiceTestCase(TestCase, XmlContent):
def test_validate_service_pgturl_sslerror(self): def test_validate_service_pgturl_sslerror(self):
"""test the retrieval of a ProxyGrantingTicket with a SSL error on the pgtUrl""" """test the retrieval of a ProxyGrantingTicket with a SSL error on the pgtUrl"""
(host, port) = utils.PGTUrlHandler.run()[1:3] (host, port) = utils.HttpParamsHandler.run()[1:3]
service = "https://%s:%s" % (host, port) service = "https://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1] ticket = get_user_ticket_request(service)[1]
@ -907,7 +907,7 @@ class ValidateServiceTestCase(TestCase, XmlContent):
test the retrieval on a ProxyGrantingTicket then to pgtUrl return a http error. test the retrieval on a ProxyGrantingTicket then to pgtUrl return a http error.
PGT creation should be aborted but the ticket still be valid PGT creation should be aborted but the ticket still be valid
""" """
(host, port) = utils.PGTUrlHandler404.run()[1:3] (host, port) = utils.Http404Handler.run()[1:3]
service = "http://%s:%s" % (host, port) service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1] ticket = get_user_ticket_request(service)[1]

View File

@ -55,14 +55,14 @@ def get_user_ticket_request(service):
def get_pgt(): def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service""" """return a dict contening a service, user and PGT ticket for this service"""
(host, port) = utils.PGTUrlHandler.run()[1:3] (httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port) service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service) (user, ticket) = get_user_ticket_request(service)[:2]
client = Client() client = Client()
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
params = utils.PGTUrlHandler.PARAMS.copy() params = httpd.PARAMS
params["service"] = service params["service"] = service
params["user"] = user params["user"] = user

View File

@ -23,6 +23,7 @@ import hashlib
import crypt import crypt
import base64 import base64
import six import six
import cgi
from threading import Thread from threading import Thread
from importlib import import_module from importlib import import_module
from six.moves import BaseHTTPServer from six.moves import BaseHTTPServer
@ -150,9 +151,11 @@ def gen_saml_id():
return _gen_ticket('_') return _gen_ticket('_')
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""A simple http server that return 200 on GET and store GET parameters. Used in unit tests""" """
PARAMS = {} A simple http server that return 200 on GET or POST
and store GET or POST parameters. Used in unit tests
"""
def do_GET(self): def do_GET(self):
"""Called on a GET request on the BaseHTTPServer""" """Called on a GET request on the BaseHTTPServer"""
@ -162,7 +165,19 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
self.wfile.write(b"ok") self.wfile.write(b"ok")
url = urlparse(self.path) url = urlparse(self.path)
params = dict(parse_qsl(url.query)) params = dict(parse_qsl(url.query))
PGTUrlHandler.PARAMS.update(params) self.server.PARAMS = params
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
if ctype == 'multipart/form-data':
postvars = cgi.parse_multipart(self.rfile, pdict)
elif ctype == 'application/x-www-form-urlencoded':
length = int(self.headers.get('content-length'))
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
else:
postvars = {}
self.server.PARAMS = postvars
def log_message(self, *args): def log_message(self, *args):
"""silent any log message""" """silent any log message"""
@ -183,10 +198,10 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
httpd_thread = Thread(target=lauch) httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True httpd_thread.daemon = True
httpd_thread.start() httpd_thread.start()
return (httpd_thread, host, port) return (httpd, host, port)
class PGTUrlHandler404(PGTUrlHandler): class Http404Handler(HttpParamsHandler):
"""A simple http server that always return 404 not found. Used in unit tests""" """A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self): def do_GET(self):
"""Called on a GET request on the BaseHTTPServer""" """Called on a GET request on the BaseHTTPServer"""
@ -195,6 +210,10 @@ class PGTUrlHandler404(PGTUrlHandler):
self.end_headers() self.end_headers()
self.wfile.write(b"error 404 not found") self.wfile.write(b"error 404 not found")
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()
class LdapHashUserPassword(object): class LdapHashUserPassword(object):
"""Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html""" """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""