Print true/false positive/negative stats

This commit is contained in:
Yohann D'ANELLO 2021-10-29 17:00:07 +02:00
parent 57d53a742f
commit d767e2029b
Signed by: ynerant
GPG Key ID: 3A75C55819C8CF85
1 changed files with 50 additions and 4 deletions

View File

@ -1,6 +1,7 @@
import argparse import argparse
import unicodedata import unicodedata
import sys import sys
from typing import Optional
import numpy as np import numpy as np
@ -23,6 +24,9 @@ def parse_args(argv: dict = None) -> argparse.Namespace:
help='Documents to read.', default=sys.stdin) help='Documents to read.', default=sys.stdin)
# Give similarity threshold. # Give similarity threshold.
parser.add_argument('similarity', nargs='?', type=float, help='Similarity threshold.', default=0.05) parser.add_argument('similarity', nargs='?', type=float, help='Similarity threshold.', default=0.05)
# Optional. Give statistics about true and false positive/negative rates, which take some time.
parser.add_argument('--stats', '-s', action='store_true',
help='Display some statistics.')
# Optional. Let to display a progress bar while generating and applying permutations, # Optional. Let to display a progress bar while generating and applying permutations,
# which is the most expensive state. # which is the most expensive state.
parser.add_argument('--progress', '-p', '--tqdm', action='store_true', parser.add_argument('--progress', '-p', '--tqdm', action='store_true',
@ -214,7 +218,8 @@ def jaccard_similarity(doc1: set, doc2: set) -> float:
return len(inter) / len(union) return len(inter) / len(union)
def parse(stream, similarity: float, display_tqdm: bool = False) -> None: def parse(stream, similarity: float, *, stats: bool = False, display_tqdm: bool = False) \
-> Optional[tuple[int, int, int, int]]:
""" """
Given a stream of documents (separated by line feeds) and a similarity threshold, Given a stream of documents (separated by line feeds) and a similarity threshold,
we display in standard output an estimation of document pairs that we display in standard output an estimation of document pairs that
@ -238,14 +243,43 @@ def parse(stream, similarity: float, display_tqdm: bool = False) -> None:
# Sort pairs for a nice output # Sort pairs for a nice output
candidate_pairs = sorted(candidate_pairs) candidate_pairs = sorted(candidate_pairs)
# Compute true and false positive counts
tp = 0
fp = 0
# For each document pair, compute true Jaccard similarity and display it # For each document pair, compute true Jaccard similarity and display it
shingles_set = [set(x for x in range(len(shingles)) if shingles[x, doc]) for doc in range(len(docs))]
for doc_a, doc_b in candidate_pairs: for doc_a, doc_b in candidate_pairs:
# Compute true jaccard similarity # Compute true jaccard similarity
shingles_a = set(x for x in range(len(shingles)) if shingles[x, doc_a]) shingles_a = shingles_set[doc_a]
shingles_b = set(x for x in range(len(shingles)) if shingles[x, doc_b]) shingles_b = shingles_set[doc_b]
d = jaccard_similarity(shingles_a, shingles_b) d = jaccard_similarity(shingles_a, shingles_b)
if d >= similarity: if d >= similarity:
print(f"{doc_a} {doc_b} {d:.06f}") print(f"{doc_a} {doc_b} {d:.06f}")
tp += 1
else:
fp += 1
if stats:
# Compute true and false negative counts, for validation only
tn = 0
fn = 0
for doc_a in range(len(docs)):
for doc_b in range(doc_a + 1, len(docs)):
# Compute true jaccard similarity
shingles_a = shingles_set[doc_a]
shingles_b = shingles_set[doc_b]
d = jaccard_similarity(shingles_a, shingles_b)
if d >= similarity and (doc_a, doc_b) not in candidate_pairs:
fn += 1
elif d < similarity and (doc_a, doc_b) not in candidate_pairs:
tn += 1
fp_rate = fp / (fp + tn)
tp_rate = tp / (tp + fn)
return tp, fp, tn, fn
def main(): def main():
@ -256,4 +290,16 @@ def main():
raise ValueError(f"Invalid similiarity value: {ns.similarity}") raise ValueError(f"Invalid similiarity value: {ns.similarity}")
# Analyse documents # Analyse documents
parse(ns.input, ns.similarity, ns.progress) output = parse(ns.input, ns.similarity, stats=ns.stats, display_tqdm=ns.progress)
if ns.stats:
tp, fp, tn, fn = output
print(f"True positive: {tp}", file=sys.stderr)
print(f"False positive: {tn}", file=sys.stderr)
print(f"True negative: {fp}", file=sys.stderr)
print(f"False negative: {fn}", file=sys.stderr)
tp_rate = tp / (tp + fn)
fp_rate = fp / (fp + tn)
print(f"True positive rate: {tp_rate:.06f}", file=sys.stderr)
print(f"False positive rate: {fp_rate:.06f}", file=sys.stderr)