Fixed NaN returned for NDCG, Rprec if no relevant docs retrieved #51

Open: wants to merge 1 commit into master.
4 changes: 2 additions & 2 deletions trectools/trec_eval.py
@@ -447,7 +447,7 @@ def get_rprec(self, depth=1000, per_query=False, trec_eval=True, removeUnjudged=
selection = pd.merge(topX, relevant_docs[["query","docid","rel"]], how="left")
selection = selection[~selection["rel"].isnull()]

- rprec_per_query = selection.groupby("query")["docid"].count() / n_relevant_docs
+ rprec_per_query = selection.groupby("query")["docid"].count().div(n_relevant_docs, fill_value=0)
rprec_per_query.name = label
rprec_per_query = rprec_per_query.reset_index().set_index("query")
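This is the crux of the Rprec fix: a query for which no relevant documents were retrieved is simply missing from the grouped count, so aligning that Series against `n_relevant_docs` with plain `/` yields NaN for it, whereas `.div(..., fill_value=0)` fills the missing numerator with 0 before dividing. A minimal sketch with invented query names (not part of the repository's fixtures):

```python
import pandas as pd

# Toy data: "q2" retrieved no relevant documents, so it never shows up in the
# grouped count, while the qrels still contain 2 relevant documents for it.
retrieved_relevant = pd.Series({"q1": 3})
n_relevant_docs = pd.Series({"q1": 4, "q2": 2})

print(retrieved_relevant / n_relevant_docs)
# q1    0.75
# q2     NaN   <- missing numerator propagates as NaN

print(retrieved_relevant.div(n_relevant_docs, fill_value=0))
# q1    0.75
# q2    0.00   <- missing numerator is treated as 0 before dividing
```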

@@ -524,7 +524,7 @@ def get_ndcg(self, depth=1000, per_query=False, trec_eval=True, removeUnjudged=F
# DCG is the sum of individual's contribution
dcg_per_query = selection[["query", label]].groupby("query").sum()
idcg_per_query = perfect_ranking[["query",label]].groupby("query").sum()
- ndcg_per_query = dcg_per_query / idcg_per_query
+ ndcg_per_query = dcg_per_query.div(idcg_per_query, fill_value=0.0)

if per_query:
return ndcg_per_query
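The NDCG change applies the same idea to the per-query DCG/IDCG frames: a query whose relevant documents were not retrieved has no rows in `selection`, so it is absent from `dcg_per_query` and plain division returns NaN, while `.div(..., fill_value=0.0)` scores it 0. A short sketch with made-up frames standing in for the groupby sums above:

```python
import pandas as pd

label = "NDCG@5"  # hypothetical metric label, standing in for the one built in get_ndcg
dcg_per_query = pd.DataFrame({label: [0.9]}, index=pd.Index(["q1"], name="query"))
idcg_per_query = pd.DataFrame({label: [1.2, 0.8]}, index=pd.Index(["q1", "q2"], name="query"))

print(dcg_per_query / idcg_per_query)                     # q2 -> NaN
print(dcg_per_query.div(idcg_per_query, fill_value=0.0))  # q2 -> 0.0
```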
2 changes: 1 addition & 1 deletion unittests/files/qrel1.txt
@@ -19,7 +19,7 @@
2 0 doc2_8 0
2 0 doc2_9 0
3 0 doc3_0 0
- 3 0 doc3_1 0
+ 3 0 doc3_1 0
3 0 doc3_2 0
3 0 doc3_3 0
3 0 doc3_4 0
15 changes: 15 additions & 0 deletions unittests/files/r5.run
@@ -0,0 +1,15 @@
1 Q0 doc1_0 1 1.0 NR
1 Q0 doc1_1 2 0.9 NR
1 Q0 doc1_3 3 0.8 NR
1 Q0 doc1_4 4 0.7 NR
1 Q0 doc1_5 5 0.6 NR
2 Q0 doc2_3 1 1.0 NR
2 Q0 doc2_4 2 0.9 NR
2 Q0 doc2_5 3 0.8 NR
2 Q0 doc2_6 4 0.7 NR
2 Q0 doc2_7 5 0.6 NR
3 Q0 doc3_1 1 1.0 NR
3 Q0 doc3_2 2 0.9 NR
3 Q0 doc3_3 3 0.8 NR
3 Q0 doc3_4 4 0.7 NR
3 Q0 doc3_5 5 0.6 NR
15 changes: 15 additions & 0 deletions unittests/files/r6.run
@@ -0,0 +1,15 @@
1 Q0 doc1_10 1 1.0 NE
1 Q0 doc1_11 2 0.9 NE
1 Q0 doc1_13 3 0.8 NE
1 Q0 doc1_14 4 0.7 NE
1 Q0 doc1_15 5 0.6 NE
2 Q0 doc2_13 1 1.0 NE
2 Q0 doc2_14 2 0.9 NE
2 Q0 doc2_15 3 0.8 NE
2 Q0 doc2_16 4 0.7 NE
2 Q0 doc2_17 5 0.6 NE
3 Q0 doc3_11 1 1.0 NE
3 Q0 doc3_12 2 0.9 NE
3 Q0 doc3_13 3 0.8 NE
3 Q0 doc3_14 4 0.7 NE
3 Q0 doc3_15 5 0.6 NE
44 changes: 43 additions & 1 deletion unittests/testtreceval.py
@@ -1,4 +1,5 @@
import unittest
import numpy as np
from trectools import TrecRun, TrecQrel, TrecEval


@@ -14,10 +15,18 @@ def setUp(self):

# Contains the first 30 documents for the first 10 topics in input.uic0301
run3 = TrecRun("./unittests/files/input.uic0301_top30")

# None of the retrieved documents is relevant according to qrel1.txt
no_rel_run = TrecRun("./unittests/files/r5.run")
# None of the retrieved documents is judged (labeled) in qrel1.txt
no_labels_run = TrecRun("./unittests/files/r6.run")

self.common_topics = ["303", "307", "310", "314", "320", "322", "325", "330", "336", "341"]
self.teval1 = TrecEval(run1, qrels1)
self.teval2 = TrecEval(run2, qrels2)
self.teval3 = TrecEval(run3, qrels2)
self.teval_no_rel = TrecEval(no_rel_run, qrels1)
self.teval_no_labels = TrecEval(no_labels_run, qrels1)

def tearDown(self):
pass
@@ -118,6 +127,39 @@ def test_get_recall(self):
for v, c in zip(values, correct_results):
self.assertAlmostEqual(v, c, places=4)

def test_no_relevant_retrieved(self):
result = self.teval_no_rel.evaluate_all(per_query=True)

# All of these should be 0
metrics = []
depths = [5, 10, 15, 20, 30, 100, 200, 500, 1000]
metrics += [f"NDCG_{i}" for i in depths]
metrics += [f"P_{i}" for i in depths]
metrics += [f"R_{i}" for i in depths]
metrics += ["map", "Rprec", "num_rel_ret"]
# recip_rank could arguably be included

metrics = result.data[result.data["metric"].isin(metrics)]

self.assertTrue(np.allclose(metrics["value"].values.astype(np.float32), 0))

def test_no_labels_retrieved(self):
result = self.teval_no_labels.evaluate_all(per_query=True)

# All of these should be 0
metrics = []
depths = [5, 10, 15, 20, 30, 100, 200, 500, 1000]
metrics += [f"NDCG_{i}" for i in depths]
metrics += [f"P_{i}" for i in depths]
metrics += [f"R_{i}" for i in depths]
metrics += ["map", "Rprec", "num_rel_ret"]
# recip_rank could arguably be included

metrics = result.data[result.data["metric"].isin(metrics)]
self.assertTrue(np.allclose(metrics["value"].values.astype(np.float32), 0))




if __name__ == '__main__':
- unittest.main()
+ unittest.main()
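For reference, a minimal end-to-end sketch of what the new fixtures exercise, assuming it is run from the repository root and that `TrecQrel` loads a qrels file by path the same way `TrecRun` loads a run file: with this patch, a run that retrieves no relevant documents should report 0 rather than NaN for NDCG and Rprec.

```python
from trectools import TrecRun, TrecQrel, TrecEval

qrels = TrecQrel("./unittests/files/qrel1.txt")
run = TrecRun("./unittests/files/r5.run")  # retrieves only non-relevant documents

result = TrecEval(run, qrels).evaluate_all(per_query=True)
print(result.data[result.data["metric"].isin(["NDCG_10", "Rprec"])])
```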