Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
AlyaGomaa committed Oct 22, 2024
2 parents 768b924 + a2e0f14 commit 69a7c8e
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
19 changes: 10 additions & 9 deletions metrics/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def F1(self):
"""

precision = self.metrics['precision']
recall = self.metrics['recall']
recall = self.TPR(log=False)
if precision + recall == 0:
f1 = 0
else:
Expand Down Expand Up @@ -207,17 +207,21 @@ def TPR(self, log=True):
tpr = 0
else:
tpr = self.metrics['TP'] / (self.metrics['TP'] + self.metrics['FN'])

if log:
self.log(f"{self.tool}: TPR: ", tpr)
return tpr

def FNR(self):
def FNR(self) -> float:
"""
FNR = 1 - TPR
FNR = FN / (FN + TP)
prints the false negative rate of a given tool
:return: float
"""
fnr = 1 - self.TPR(log=False)
try:
fnr = self.metrics["FN"] / (self.metrics["FN"] + self.metrics["TP"])
except ZeroDivisionError:
fnr = 0

self.log(f"{self.tool}: FNR: ", fnr)
return fnr

Expand Down Expand Up @@ -262,14 +266,11 @@ def calc_all_metrics(self) -> Dict[str, float]:
self.FNR,
self.TPR,
self.TNR,
self.recall,
self.precision,
self.F1,
self.accuracy,
self.MCC,
):
res.update({metric.__name__ : metric()})
return res
# def __del__(self):
# if hasattr(self, 'db'):
# self.db.close()

Empty file.
File renamed without changes.
File renamed without changes.
9 changes: 4 additions & 5 deletions scripts/slips_metrics_getter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from argparse import ArgumentParser

from plot.plot import Plot
from scripts.extracted_levels import extracted_threat_levels
from scripts.extracted_gt_tw_labels import gt_tw_labels
from scripts.extracted_scores.extracted_levels import extracted_threat_levels
from scripts.extracted_scores.extracted_gt_tw_labels import gt_tw_labels
from metrics.calculator import Calculator


Expand Down Expand Up @@ -69,7 +69,7 @@ def print_extremes(threshold_with_min_max: Dict[str, Dict]):
"""
Print the extreme values for each metric and their corresponding thresholds.
"""
print("Best/Worst thresholds so far")
print("Below are the minimum and maximum of the 4 error metrics with the threshold that resulted in them.")
for metric, info in threshold_with_min_max.items():
print(f"{metric}:")
print(f" Min value: {info['min_value']}, Threshold: {info['min_threshold']}")
Expand Down Expand Up @@ -232,14 +232,13 @@ def main():
# calc.calc_all_metrics()
experiment_metrics = {
'MCC': calc.MCC(),
'recall': calc.recall(),
'precision': calc.precision(),
'F1': calc.F1(),
'FPR': calc.FPR(),
'TPR': calc.TPR(),
'FNR': calc.FNR(),
'TNR': calc.TNR(),
'accuracy': calc.accuracy(),
'F1': calc.F1(),
}
experiment_metrics.update(confusion_matrix)
metrics[threshold].update({exp: experiment_metrics})
Expand Down

0 comments on commit 69a7c8e

Please sign in to comment.