Add option --progress to display the progress bar while generating permutations
This commit is contained in:
parent
78efd7fb2a
commit
310b43483e
|
@ -1,5 +1,4 @@
|
|||
import argparse
|
||||
from tqdm import tqdm
|
||||
import unicodedata
|
||||
import sys
|
||||
|
||||
|
@ -15,6 +14,8 @@ def parse_args(argv: dict = None) -> argparse.Namespace:
|
|||
parser = argparse.ArgumentParser(description='Exercise 1')
|
||||
parser.add_argument('input', nargs='?', type=argparse.FileType('r'), help='Documents to read.', default=sys.stdin)
|
||||
parser.add_argument('similarity', nargs='?', type=float, help='Similarity threshold.', default=0.05)
|
||||
parser.add_argument('--progress', '-p', '--tqdm', action='store_true',
|
||||
help='Display progress bar while calculating signature matrix.')
|
||||
|
||||
return parser.parse_args(argv[1:])
|
||||
|
||||
|
@ -55,12 +56,20 @@ def min_hash(doc: str, perm: list[str]) -> str:
|
|||
return d
|
||||
|
||||
|
||||
def compute_signature_matrix(shingles: np.ndarray, permutations_count: int) -> np.ndarray:
|
||||
def compute_signature_matrix(shingles: np.ndarray, permutations_count: int, display_tqdm: bool = False) -> np.ndarray:
|
||||
shingles_count, docs_count = shingles.shape
|
||||
|
||||
signature_matrix = np.inf * np.ones((permutations_count, docs_count))
|
||||
|
||||
for permutation_id in tqdm(range(permutations_count)):
|
||||
permutations_iterator = range(permutations_count)
|
||||
if display_tqdm:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
permutations_iterator = tqdm(permutations_iterator)
|
||||
except ImportError:
|
||||
print("tqdm is not installed. Please install tqdm before using --tqdm option.")
|
||||
|
||||
for permutation_id in permutations_iterator:
|
||||
permutation = np.random.permutation(shingles)
|
||||
signature_matrix[permutation_id] = permutation.argmax(0)
|
||||
|
||||
|
@ -76,7 +85,7 @@ def jaccard_similarity(doc1: set, doc2: set) -> float:
|
|||
return len(inter) / len(union)
|
||||
|
||||
|
||||
def parse(stream, similarity: float) -> None:
|
||||
def parse(stream, similarity: float, display_tqdm: bool = False) -> None:
|
||||
docs = [line.rstrip('\n') for line in stream]
|
||||
docs = [normalize(doc) for doc in docs] # Remove special characters and normalize accents
|
||||
|
||||
|
@ -94,7 +103,7 @@ def parse(stream, similarity: float) -> None:
|
|||
rows *= 2
|
||||
threshold = (1 / bands) ** (1 / rows)
|
||||
|
||||
signature = compute_signature_matrix(shingles, bands * rows)
|
||||
signature = compute_signature_matrix(shingles, bands * rows, display_tqdm)
|
||||
|
||||
candidate_pairs = set()
|
||||
|
||||
|
@ -128,4 +137,4 @@ def parse(stream, similarity: float) -> None:
|
|||
|
||||
def main():
|
||||
ns = parse_args()
|
||||
parse(ns.input, ns.similarity)
|
||||
parse(ns.input, ns.similarity, ns.progress)
|
||||
|
|
Loading…
Reference in New Issue