Skip to content

Commit

Permalink
added SetR
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed May 30, 2021
1 parent b35c9d2 commit 16347d2
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ir_measures/measures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def register(measure, aliases=[], name=None):
from .rbp import RBP, _RBP
from .rprec import Rprec, RPrec, _Rprec
from .rr import RR, MRR, _RR
from .setp import SetP, _SetP
from .set_measures import SetP, _SetP, SetR, _SetR
from .success import Success, _Success

# enable from "ir_measures.measures import *" --- on purpose, do not include _-prefixed versions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ class _SetP(measures.Measure):
'rel': measures.ParamInfo(dtype=int, default=1, desc='minimum relevance score to be considered relevant (inclusive)')
}


SetP = _SetP()
measures.register(SetP)


class _SetR(measures.Measure):
    """
    The Set Recall (SetR); i.e., the number of relevant documents that were retrieved divided by the
    total number of relevant documents for the query.

    The only supported parameter is ``rel`` (the inclusive minimum relevance score); no rank cutoff
    parameter is declared for this measure.
    """
    # Name of the measure; providers dispatch on NAME (e.g. ``measure.NAME == 'SetR'``).
    __name__ = 'SetR'
    NAME = __name__
    # Parameters accepted by this measure, with their type, default and description.
    SUPPORTED_PARAMS = {
        'rel': measures.ParamInfo(dtype=int, default=1, desc='minimum relevance score to be considered relevant (inclusive)')
    }

# Default instance of the measure (rel falls back to its declared default of 1), passed to the
# package's measure registration helper.
SetR = _SetR()
measures.register(SetR)
4 changes: 4 additions & 0 deletions ir_measures/providers/pytrec_eval_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class PytrecEvalProvider(providers.Provider):
measures._NumQ(),
measures._NumRel(rel=Choices(1)), # for some reason, relevance_level doesn't flow through to num_rel, so can only support rel=1
measures._SetP(rel=Any()),
measures._SetR(rel=Any()),
measures._Success(rel=Any(), cutoff=Any()),
measures._IPrec(recall=Any()),
measures._infAP(rel=Any()),
Expand Down Expand Up @@ -115,6 +116,9 @@ def _build_invokers(self, measures, qrels):
elif measure.NAME == 'SetP':
invocation_key = (measure['rel'],)
measure_str = f'set_P'
elif measure.NAME == 'SetR':
invocation_key = (measure['rel'],)
measure_str = f'set_recall'
elif measure.NAME == 'Success':
invocation_key = (measure['rel'],)
measure_str = f'success_{measure["cutoff"]}'
Expand Down
2 changes: 1 addition & 1 deletion ir_measures/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def parse_trec_measure(measure: str) -> List['Measure']:
'ndcg_rel': (None, None, None),
'recip_rank': (ir_measures.RR, None, None),
'recall': (ir_measures.R, 'cutoff', [5, 10, 15, 20, 30, 100, 200, 500, 1000]),
'set_recall': (None, None, None),
'set_recall': (ir_measures.SetR, None, None),
'utility': (None, None, None),
'set_relative_P': (None, None, None),
'num_ret': (ir_measures.NumRet, None, None),
Expand Down
47 changes: 47 additions & 0 deletions test/test_pytrec_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,53 @@ def test_SetP(self):
self.assertEqual(result[1].value, 0)
self.assertEqual(provider.calc_aggregate([measure], qrels, run)[measure], 0)

def test_SetR(self):
    # Checks SetR through the pytrec_eval provider at relevance thresholds 1, 2 and 3:
    # per-query values first, then the aggregate over both queries.
    judgments = list(ir_measures.read_trec_qrels('''
0 0 D0 0
0 0 D1 1
0 0 D2 2
0 0 D3 2
0 0 D4 0
1 0 D0 1
1 0 D3 2
1 0 D5 0
'''))
    ranking = list(ir_measures.read_trec_run('''
0 0 D0 1 0.8 run
0 0 D2 2 0.7 run
0 0 D1 3 0.3 run
0 0 D3 4 0.4 run
0 0 D4 5 0.1 run
1 0 D1 1 0.8 run
1 0 D4 2 0.7 run
1 0 D3 3 0.3 run
1 0 D2 4 0.4 run
'''))
    provider = ir_measures.providers.PytrecEvalProvider()

    # Each row: (rel threshold, expected value for query "0", expected value for query "1", expected aggregate)
    expectations = (
        # rel>=1: query 0 retrieves all of D1/D2/D3; query 1 retrieves only D3 out of D0/D3
        (1, 1., .5, 0.75),
        # rel>=2: query 0 retrieves both D2/D3; query 1 retrieves its single such doc, D3
        (2, 1.0, 1.0, 1.0),
        # rel>=3: no document is judged that high for either query
        (3, 0, 0, 0),
    )
    for min_rel, want_q0, want_q1, want_aggregate in expectations:
        measure = ir_measures.SetR(rel=min_rel)
        per_query = list(provider.iter_calc([measure], judgments, ranking))
        for position, (query_id, want_value) in enumerate((("0", want_q0), ("1", want_q1))):
            self.assertEqual(per_query[position].query_id, query_id)
            self.assertEqual(per_query[position].value, want_value)
        aggregate = provider.calc_aggregate([measure], judgments, ranking)
        self.assertEqual(aggregate[measure], want_aggregate)


def test_IPrec(self):
qrels = list(ir_measures.read_trec_qrels('''
0 0 D0 1
Expand Down
4 changes: 3 additions & 1 deletion test/test_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import itertools
import ir_measures
from ir_measures import AP, P, nDCG, NumRel, NumRelRet, Bpref, NumQ, RR, Rprec, NumRet, IPrec
from ir_measures import *

class TestUtil(unittest.TestCase):

Expand All @@ -19,6 +19,8 @@ def test_parse_trec_measure(self):
'ndcg_cut_10': [nDCG@10],
'ndcg_cut_5,10': [nDCG@5, nDCG@10],
'ndcg_cut': [nDCG@5, nDCG@10, nDCG@15, nDCG@20, nDCG@30, nDCG@100, nDCG@200, nDCG@500, nDCG@1000],
'set_P': [SetP],
'set_recall': [SetR],
'official': [P@5, P@10, P@15, P@20, P@30, P@100, P@200, P@500, P@1000, Rprec, Bpref, IPrec@0.0, IPrec@0.1, IPrec@0.2, IPrec@0.3, IPrec@0.4, IPrec@0.5, IPrec@0.6, IPrec@0.7, IPrec@0.8, IPrec@0.9, IPrec@1.0, AP, NumQ, NumRel, NumRelRet, NumRet, RR],
}
for case in cases:
Expand Down

0 comments on commit 16347d2

Please sign in to comment.