diff --git a/algods/algods.py b/algods/algods.py index 86db147..213736b 100644 --- a/algods/algods.py +++ b/algods/algods.py @@ -1,5 +1,4 @@ import argparse -from tqdm import tqdm import unicodedata import sys @@ -15,6 +14,8 @@ def parse_args(argv: dict = None) -> argparse.Namespace: parser = argparse.ArgumentParser(description='Exercise 1') parser.add_argument('input', nargs='?', type=argparse.FileType('r'), help='Documents to read.', default=sys.stdin) parser.add_argument('similarity', nargs='?', type=float, help='Similarity threshold.', default=0.05) + parser.add_argument('--progress', '-p', '--tqdm', action='store_true', + help='Display progress bar while calculating signature matrix.') return parser.parse_args(argv[1:]) @@ -55,12 +56,20 @@ def min_hash(doc: str, perm: list[str]) -> str: return d -def compute_signature_matrix(shingles: np.ndarray, permutations_count: int) -> np.ndarray: +def compute_signature_matrix(shingles: np.ndarray, permutations_count: int, display_tqdm: bool = False) -> np.ndarray: shingles_count, docs_count = shingles.shape signature_matrix = np.inf * np.ones((permutations_count, docs_count)) - for permutation_id in tqdm(range(permutations_count)): + permutations_iterator = range(permutations_count) + if display_tqdm: + try: + from tqdm import tqdm + permutations_iterator = tqdm(permutations_iterator) + except ImportError: + print("tqdm is not installed. Please install tqdm before using --tqdm option.") + + for permutation_id in permutations_iterator: permutation = np.random.permutation(shingles) signature_matrix[permutation_id] = permutation.argmax(0) @@ -76,7 +85,7 @@ def jaccard_similarity(doc1: set, doc2: set) -> float: return len(inter) / len(union) -def parse(stream, similarity: float) -> None: +def parse(stream, similarity: float, display_tqdm: bool = False) -> None: docs = [line.rstrip('\n') for line in stream] docs = [normalize(doc) for doc in docs] # Remove special characters and normalize accents @@ -94,7 +103,7 @@ def parse(stream, similarity: float) -> None: rows *= 2 threshold = (1 / bands) ** (1 / rows) - signature = compute_signature_matrix(shingles, bands * rows) + signature = compute_signature_matrix(shingles, bands * rows, display_tqdm) candidate_pairs = set() @@ -128,4 +137,4 @@ def parse(stream, similarity: float) -> None: def main(): ns = parse_args() - parse(ns.input, ns.similarity) + parse(ns.input, ns.similarity, ns.progress)