Add option --progress to display the progress bar while generating permutations

This commit is contained in:
Yohann D'ANELLO 2021-10-29 14:34:17 +02:00
parent 78efd7fb2a
commit 310b43483e
Signed by: ynerant
GPG Key ID: 3A75C55819C8CF85
1 changed files with 15 additions and 6 deletions

View File

@ -1,5 +1,4 @@
import argparse
from tqdm import tqdm
import unicodedata
import sys
@ -15,6 +14,8 @@ def parse_args(argv: dict = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description='Exercise 1')
parser.add_argument('input', nargs='?', type=argparse.FileType('r'), help='Documents to read.', default=sys.stdin)
parser.add_argument('similarity', nargs='?', type=float, help='Similarity threshold.', default=0.05)
parser.add_argument('--progress', '-p', '--tqdm', action='store_true',
help='Display progress bar while calculating signature matrix.')
return parser.parse_args(argv[1:])
@ -55,12 +56,20 @@ def min_hash(doc: str, perm: list[str]) -> str:
return d
def compute_signature_matrix(shingles: np.ndarray, permutations_count: int) -> np.ndarray:
def compute_signature_matrix(shingles: np.ndarray, permutations_count: int, display_tqdm: bool = False) -> np.ndarray:
shingles_count, docs_count = shingles.shape
signature_matrix = np.inf * np.ones((permutations_count, docs_count))
for permutation_id in tqdm(range(permutations_count)):
permutations_iterator = range(permutations_count)
if display_tqdm:
try:
from tqdm import tqdm
permutations_iterator = tqdm(permutations_iterator)
except ImportError:
print("tqdm is not installed. Please install tqdm before using --tqdm option.")
for permutation_id in permutations_iterator:
permutation = np.random.permutation(shingles)
signature_matrix[permutation_id] = permutation.argmax(0)
@ -76,7 +85,7 @@ def jaccard_similarity(doc1: set, doc2: set) -> float:
return len(inter) / len(union)
def parse(stream, similarity: float) -> None:
def parse(stream, similarity: float, display_tqdm: bool = False) -> None:
docs = [line.rstrip('\n') for line in stream]
docs = [normalize(doc) for doc in docs] # Remove special characters and normalize accents
@ -94,7 +103,7 @@ def parse(stream, similarity: float) -> None:
rows *= 2
threshold = (1 / bands) ** (1 / rows)
signature = compute_signature_matrix(shingles, bands * rows)
signature = compute_signature_matrix(shingles, bands * rows, display_tqdm)
candidate_pairs = set()
@ -128,4 +137,4 @@ def parse(stream, similarity: float) -> None:
def main():
ns = parse_args()
parse(ns.input, ns.similarity)
parse(ns.input, ns.similarity, ns.progress)