Add unit test for the utils function check_password

This commit is contained in:
Valentin Samir 2016-06-26 21:44:41 +02:00
parent 6faeaad57e
commit 2fac47f0b1
3 changed files with 66 additions and 10 deletions

View File

@ -15,7 +15,6 @@ from django.contrib.auth import get_user_model
try: # pragma: no cover try: # pragma: no cover
import MySQLdb import MySQLdb
import MySQLdb.cursors import MySQLdb.cursors
import crypt
from utils import check_password from utils import check_password
except ImportError: except ImportError:
MySQLdb = None MySQLdb = None

View File

@ -3,6 +3,7 @@ from .default_settings import settings
from django.test import TestCase from django.test import TestCase
from django.test import Client from django.test import Client
import six
from lxml import etree from lxml import etree
from cas_server import models from cas_server import models
@ -59,6 +60,60 @@ def get_pgt():
return params return params
class CheckPasswordCase(TestCase):
def setUp(self):
self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes):
self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8")
def test_setup(self):
self.assertIsInstance(self.password1, bytes)
self.assertIsInstance(self.password2, bytes)
def test_plain(self):
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
def test_crypt(self):
if six.PY3:
hashed_password1 = utils.crypt.crypt(
self.password1.decode("utf8"),
"$6$UVVAQvrMyXMF3FF3"
).encode("utf8")
else:
hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
def test_ldap_ssha(self):
salt = b"UVVAQvrMyXMF3FF3"
hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
self.assertIsInstance(hashed_password1, bytes)
self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
def test_hex_md5(self):
hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
def test_hox_sha512(self):
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
self.assertTrue(
utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
)
self.assertFalse(
utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
)
class LoginTestCase(TestCase): class LoginTestCase(TestCase):
def setUp(self): def setUp(self):

View File

@ -177,6 +177,7 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
httpd_thread.start() httpd_thread.start()
return (httpd_thread, host, port) return (httpd_thread, host, port)
class LdapHashUserPassword(object): class LdapHashUserPassword(object):
"""Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html""" """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""
@ -204,8 +205,6 @@ class LdapHashUserPassword(object):
b"{SSHA512}": 64, b"{SSHA512}": 64,
} }
class BadScheme(ValueError): class BadScheme(ValueError):
pass pass
@ -217,9 +216,9 @@ class LdapHashUserPassword(object):
@classmethod @classmethod
def _raise_bad_scheme(cls, scheme, valid, msg): def _raise_bad_scheme(cls, scheme, valid, msg):
valid_schemes = [s for s in valid] valid_schemes = [s.decode() for s in valid]
valid_schemes.sort() valid_schemes.sort()
raise cls.BadScheme(msg % (scheme, ", ".join(valid_schemes))) raise cls.BadScheme(msg % (scheme, u", ".join(valid_schemes)))
@classmethod @classmethod
def _test_scheme(cls, scheme): def _test_scheme(cls, scheme):
@ -258,7 +257,9 @@ class LdapHashUserPassword(object):
elif salt is not None: elif salt is not None:
cls._test_scheme_salt(scheme) cls._test_scheme_salt(scheme)
try: try:
return scheme + base64.b64encode(cls._schemes_to_hash[scheme](password + salt).digest() + salt) return scheme + base64.b64encode(
cls._schemes_to_hash[scheme](password + salt).digest() + salt
)
except KeyError: except KeyError:
if six.PY3: if six.PY3:
password = password.decode(charset) password = password.decode(charset)
@ -272,13 +273,12 @@ class LdapHashUserPassword(object):
@classmethod @classmethod
def get_scheme(cls, hashed_passord): def get_scheme(cls, hashed_passord):
if not hashed_passord[0] == b'{' or not b'}' in hashed_passord: if not hashed_passord[0] == b'{'[0] or b'}' not in hashed_passord:
raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord) raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord)
scheme = hashed_passord.split(b'}', 1)[0] scheme = hashed_passord.split(b'}', 1)[0]
scheme = scheme.upper() + b"}" scheme = scheme.upper() + b"}"
return scheme return scheme
@classmethod @classmethod
def get_salt(cls, hashed_passord): def get_salt(cls, hashed_passord):
scheme = cls.get_scheme(hashed_passord) scheme = cls.get_scheme(hashed_passord)
@ -294,7 +294,6 @@ class LdapHashUserPassword(object):
return hashed_passord[cls._schemes_to_len[scheme]:] return hashed_passord[cls._schemes_to_len[scheme]:]
def check_password(method, password, hashed_password, charset): def check_password(method, password, hashed_password, charset):
if not isinstance(password, six.binary_type): if not isinstance(password, six.binary_type):
password = password.encode(charset) password = password.encode(charset)
@ -325,6 +324,9 @@ def check_password(method, password, hashed_password, charset):
method.startswith("hex_") and method.startswith("hex_") and
method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"} method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"}
): ):
return getattr(hashlib, method[4:])(password).hexdigest() == hashed_password.lower() return getattr(
hashlib,
method[4:]
)(password).hexdigest().encode("ascii") == hashed_password.lower()
else: else:
raise ValueError("Unknown password method check %r" % method) raise ValueError("Unknown password method check %r" % method)