This commit is contained in:
Valentin Samir 2016-06-26 11:16:41 +02:00
parent 3e80a018dd
commit ac5f359063
4 changed files with 23 additions and 29 deletions

View File

@ -26,11 +26,11 @@ class AuthUser(object):
def test_password(self, password): def test_password(self, password):
"""test `password` agains the user""" """test `password` agains the user"""
raise NotImplemented() raise NotImplementedError()
def attributs(self): def attributs(self):
"""return a dict of user attributes""" """return a dict of user attributes"""
raise NotImplemented() raise NotImplementedError()
class DummyAuthUser(AuthUser): class DummyAuthUser(AuthUser):

View File

@ -44,7 +44,7 @@ def get_user_ticket_request(service):
def get_pgt(): def get_pgt():
(httpd_thread, host, port) = utils.PGTUrlHandler.run() (host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port) service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service) (user, ticket) = get_user_ticket_request(service)
@ -326,7 +326,7 @@ class ValidateTestCase(TestCase):
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_view_ok(self): def test_validate_view_ok(self):
(user, ticket) = get_user_ticket_request(self.service) ticket = get_user_ticket_request(self.service)[1]
client = Client() client = Client()
response = client.get('/validate', {'ticket': ticket.value, 'service': self.service}) response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
@ -334,7 +334,7 @@ class ValidateTestCase(TestCase):
self.assertEqual(response.content, b'yes\ntest\n') self.assertEqual(response.content, b'yes\ntest\n')
def test_validate_view_badservice(self): def test_validate_view_badservice(self):
(user, ticket) = get_user_ticket_request(self.service) ticket = get_user_ticket_request(self.service)[1]
client = Client() client = Client()
response = client.get( response = client.get(
@ -345,7 +345,7 @@ class ValidateTestCase(TestCase):
self.assertEqual(response.content, b'no\n') self.assertEqual(response.content, b'no\n')
def test_validate_view_badticket(self): def test_validate_view_badticket(self):
(user, ticket) = get_user_ticket_request(self.service) get_user_ticket_request(self.service)
client = Client() client = Client()
response = client.get( response = client.get(
@ -369,7 +369,7 @@ class ValidateServiceTestCase(TestCase):
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_service_view_ok(self): def test_validate_service_view_ok(self):
(user, ticket) = get_user_ticket_request(self.service) ticket = get_user_ticket_request(self.service)[1]
client = Client() client = Client()
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service}) response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
@ -404,7 +404,7 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
def test_validate_service_view_badservice(self): def test_validate_service_view_badservice(self):
(user, ticket) = get_user_ticket_request(self.service) ticket = get_user_ticket_request(self.service)[1]
client = Client() client = Client()
bad_service = "https://www.example.org" bad_service = "https://www.example.org"
@ -421,7 +421,7 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(error[0].text, bad_service) self.assertEqual(error[0].text, bad_service)
def test_validate_service_view_badticket_goodprefix(self): def test_validate_service_view_badticket_goodprefix(self):
(user, ticket) = get_user_ticket_request(self.service) get_user_ticket_request(self.service)
client = Client() client = Client()
bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
@ -438,7 +438,7 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(error[0].text, 'ticket not found') self.assertEqual(error[0].text, 'ticket not found')
def test_validate_service_view_badticket_badprefix(self): def test_validate_service_view_badticket_badprefix(self):
(user, ticket) = get_user_ticket_request(self.service) get_user_ticket_request(self.service)
client = Client() client = Client()
bad_ticket = "RANDOM" bad_ticket = "RANDOM"
@ -455,10 +455,10 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(error[0].text, bad_ticket) self.assertEqual(error[0].text, bad_ticket)
def test_validate_service_view_ok_pgturl(self): def test_validate_service_view_ok_pgturl(self):
(httpd_thread, host, port) = utils.PGTUrlHandler.run() (host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port) service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service) ticket = get_user_ticket_request(service)[1]
client = Client() client = Client()
response = client.get( response = client.get(
@ -480,7 +480,7 @@ class ValidateServiceTestCase(TestCase):
def test_validate_service_pgturl_bad_proxy_callback(self): def test_validate_service_pgturl_bad_proxy_callback(self):
self.service_pattern.proxy_callback = False self.service_pattern.proxy_callback = False
self.service_pattern.save() self.service_pattern.save()
(user, ticket) = get_user_ticket_request(self.service) ticket = get_user_ticket_request(self.service)[1]
client = Client() client = Client()
response = client.get( response = client.get(

View File

@ -80,9 +80,9 @@ def update_url(url, params):
query = dict(parse_qsl(url_parts[4])) query = dict(parse_qsl(url_parts[4]))
query.update(params) query.update(params)
url_parts[4] = urlencode(query) url_parts[4] = urlencode(query)
for i in range(len(url_parts)): for i, url_part in enumerate(url_parts):
if not isinstance(url_parts[i], bytes): if not isinstance(url_part, bytes):
url_parts[i] = url_parts[i].encode('utf-8') url_parts[i] = url_part.encode('utf-8')
return urlunparse(url_parts).decode('utf-8') return urlunparse(url_parts).decode('utf-8')
@ -155,7 +155,7 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
params = dict(parse_qsl(url.query)) params = dict(parse_qsl(url.query))
PGTUrlHandler.PARAMS.update(params) PGTUrlHandler.PARAMS.update(params)
def log_message(self, format, *args): def log_message(self, template, *args):
return return
@staticmethod @staticmethod

View File

@ -63,12 +63,12 @@ class AttributesMixin(object):
class LogoutMixin(object): class LogoutMixin(object):
"""destroy CAS session utils""" """destroy CAS session utils"""
def logout(self, all=False): def logout(self, all_session=False):
"""effectively destroy CAS session""" """effectively destroy CAS session"""
session_nb = 0 session_nb = 0
username = self.request.session.get("username") username = self.request.session.get("username")
if username: if username:
if all: if all_session:
logger.info("Logging out user %s from all of they sessions." % username) logger.info("Logging out user %s from all of they sessions." % username)
else: else:
logger.info("Logging out user %s." % username) logger.info("Logging out user %s." % username)
@ -86,8 +86,8 @@ class LogoutMixin(object):
# if user not found in database, flush the session anyway # if user not found in database, flush the session anyway
self.request.session.flush() self.request.session.flush()
# If all is set logout user from alternative sessions # If all_session is set logout user from alternative sessions
if all: if all_session:
for user in models.User.objects.filter(username=username): for user in models.User.objects.filter(username=username):
session = SessionStore(session_key=user.session_key) session = SessionStore(session_key=user.session_key)
session.flush() session.flush()
@ -198,10 +198,7 @@ class LoginView(View, LogoutMixin):
def init_post(self, request): def init_post(self, request):
self.request = request self.request = request
self.service = request.POST.get('service') self.service = request.POST.get('service')
if request.POST.get('renew') and request.POST['renew'] != "False": self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
self.renew = True
else:
self.renew = False
self.gateway = request.POST.get('gateway') self.gateway = request.POST.get('gateway')
self.method = request.POST.get('method') self.method = request.POST.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = 'HTTP_X_AJAX' in request.META
@ -285,10 +282,7 @@ class LoginView(View, LogoutMixin):
def init_get(self, request): def init_get(self, request):
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
if request.GET.get('renew') and request.GET['renew'] != "False": self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
self.renew = True
else:
self.renew = False
self.gateway = request.GET.get('gateway') self.gateway = request.GET.get('gateway')
self.method = request.GET.get('method') self.method = request.GET.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = 'HTTP_X_AJAX' in request.META