Skip to content

Commit

Permalink
added SetF
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed May 30, 2021
1 parent 16347d2 commit 3e96cac
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 22 deletions.
2 changes: 1 addition & 1 deletion ir_measures/measures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def register(measure, aliases=[], name=None):
from .rbp import RBP, _RBP
from .rprec import Rprec, RPrec, _Rprec
from .rr import RR, MRR, _RR
from .set_measures import SetP, _SetP, SetR, _SetR
from .set_measures import SetP, _SetP, SetR, _SetR, SetF, _SetF
from .success import Success, _Success

# enable from "ir_measures.measures import *" --- on purpose, do not include _-prefixed versions,
Expand Down
14 changes: 14 additions & 0 deletions ir_measures/measures/set_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,17 @@ class _SetR(measures.Measure):

SetR = _SetR()
measures.register(SetR)

class _SetF(measures.Measure):
    """
    The Set F measure (SetF); i.e., the harmonic mean of SetP and SetR.

    The `beta` parameter sets the relative importance of recall (SetR) to
    precision (SetP) in that mean; it defaults to 1. The `rel` parameter is
    the minimum relevance score (inclusive) for a document to count as
    relevant.
    """
    # Both names are bound to the same string, __name__ first.
    __name__ = NAME = 'SetF'
    SUPPORTED_PARAMS = dict(
        rel=measures.ParamInfo(
            dtype=int,
            default=1,
            desc='minimum relevance score to be considered relevant (inclusive)',
        ),
        beta=measures.ParamInfo(
            dtype=float,
            default=1.0,
            desc='relative importance of R to P in the harmonic mean',
        ),
    )

SetF = _SetF()
measures.register(SetF)
58 changes: 38 additions & 20 deletions ir_measures/providers/pytrec_eval_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,83 +57,101 @@ def _evaluator(self, measures, qrels):

def _build_invokers(self, measures, qrels):
invocations = {}
setf_count = 0
for measure in measures:
match_str = None
if measure.NAME == 'P':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'P_{measure["cutoff"]}'
elif measure.NAME == 'RR':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'recip_rank'
elif measure.NAME == 'Rprec':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'Rprec'
elif measure.NAME == 'AP':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
if measure['cutoff'] is NOT_PROVIDED:
measure_str = f'map'
else:
measure_str = f'map_cut_{measure["cutoff"]}'
elif measure.NAME == 'infAP':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'infAP'
elif measure.NAME == 'nDCG':
# Doesn't matter where this goes... Put it in an existing invocation, or just (1, 0) if none exists yet
if invocations:
invocation_key = next(iter(invocations))
else:
invocation_key = (1,)
invocation_key = (1, 0)
if measure['cutoff'] is NOT_PROVIDED:
measure_str = f'ndcg'
else:
measure_str = f'ndcg_cut_{measure["cutoff"]}'
elif measure.NAME == 'R':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'recall_{measure["cutoff"]}'
elif measure.NAME == 'Bpref':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'bpref'
elif measure.NAME == 'NumRet':
if measure['rel'] is NOT_PROVIDED:
# Doesn't matter where this goes... Put it in an existing invocation, or just (1, 0) if none exists yet
if invocations:
invocation_key = next(iter(invocations))
else:
invocation_key = (1,)
invocation_key = (1, 0)
measure_str = 'num_ret'
else:
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = 'num_rel_ret'
elif measure.NAME == 'NumQ':
# Doesn't matter where this goes... Put it in an existing invocation, or just (1, 0) if none exists yet
if invocations:
invocation_key = next(iter(invocations))
else:
invocation_key = (1,)
invocation_key = (1, 0)
measure_str = 'num_q'
elif measure.NAME == 'NumRel':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = 'num_rel'
elif measure.NAME == 'SetF':
# set_F is strange (or buggy?) in both trec_eval and pytrec_eval. It only accepts
# the first beta argument it's given, which is why we use the setf_count approach
# to handle multiple invocations. It also is always reported as the name set_F by
# pytrec_eval, so we need different measure_str and match_str here.
invocation_key = (measure['rel'], setf_count)
setf_count += 1
measure_str = f'set_F_{measure["beta"]}'
match_str = 'set_F'
if measure['beta'] == 1.:
measure_str = f'set_F'
else:
measure_str = f'set_F_{measure["beta"]}'
elif measure.NAME == 'SetP':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'set_P'
elif measure.NAME == 'SetR':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'set_recall'
elif measure.NAME == 'Success':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'success_{measure["cutoff"]}'
elif measure.NAME == 'IPrec':
invocation_key = (measure['rel'],)
invocation_key = (measure['rel'], 0)
measure_str = f'iprec_at_recall_{measure["recall"]:.2f}'
else:
raise ValueError(f'unsupported measure {measure}')

if match_str is None:
match_str = measure_str

if invocation_key not in invocations:
invocations[invocation_key] = {}
invocations[invocation_key][measure_str] = measure
invocations[invocation_key][match_str] = (measure, measure_str)

invokers = []
for (rel_level, ), measure_map in invocations.items():
for (rel_level, it), measure_map in invocations.items():
invokers.append(PytrecEvalInvoker(self.pytrec_eval, qrels, measure_map, rel_level))

return invokers
Expand All @@ -160,14 +178,14 @@ def iter_calc(self, run):

class PytrecEvalInvoker:
def __init__(self, pte, qrels, measure_map, rel_level):
self.evaluator = pte.RelevanceEvaluator(qrels, measure_map.keys(), relevance_level=rel_level)
self.evaluator = pte.RelevanceEvaluator(qrels, [m for _, m in measure_map.values()], relevance_level=rel_level)
self.measure_map = measure_map

def iter_calc(self, run):
result = self.evaluator.evaluate(run)
for query_id, measures in result.items():
for measure_str, value in measures.items():
yield Metric(query_id=query_id, measure=self.measure_map[measure_str], value=value)
yield Metric(query_id=query_id, measure=self.measure_map[measure_str][0], value=value)


providers.register(PytrecEvalProvider())
2 changes: 1 addition & 1 deletion ir_measures/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def parse_trec_measure(measure: str) -> List['Measure']:
'map': (ir_measures.AP, None, None),
'G': (None, None, None),
'success': (ir_measures.Success, 'cutoff', [1, 5, 10]),
'set_F': (None, None, None),
'set_F': (ir_measures.SetF, 'beta', [1.]),
'iprec_at_recall': (ir_measures.IPrec, 'recall', [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]),
}
import pytrec_eval
Expand Down
63 changes: 63 additions & 0 deletions test/test_pytrec_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,69 @@ def test_SetR(self):
self.assertEqual(result[1].value, 0)
self.assertEqual(provider.calc_aggregate([measure], qrels, run)[measure], 0)

def test_SetF(self):
    # Exercises SetF through the pytrec_eval provider: default beta, two
    # non-default betas, a relevance threshold above every judgment, and
    # finally all four measures requested in a single call.
    qrels = list(ir_measures.read_trec_qrels('''
0 0 D0 0
0 0 D1 1
0 0 D2 2
0 0 D3 2
0 0 D4 0
1 0 D0 1
1 0 D3 2
1 0 D5 0
'''))
    run = list(ir_measures.read_trec_run('''
0 0 D0 1 0.8 run
0 0 D2 2 0.7 run
0 0 D1 3 0.3 run
0 0 D3 4 0.4 run
0 0 D4 5 0.1 run
1 0 D1 1 0.8 run
1 0 D4 2 0.7 run
1 0 D3 3 0.3 run
1 0 D2 4 0.4 run
'''))
    provider = ir_measures.providers.PytrecEvalProvider()

    def per_query(m):
        # Per-query results for a single measure, in provider order.
        return list(provider.iter_calc([m], qrels, run))

    def aggregate(m):
        # Aggregated value for a single measure.
        return provider.calc_aggregate([m], qrels, run)[m]

    # rel=1 with the default beta
    f_default = ir_measures.SetF(rel=1)
    rows = per_query(f_default)
    self.assertEqual(rows[0].query_id, "0")
    self.assertAlmostEqual(rows[0].value, 0.75, places=4)
    self.assertEqual(rows[1].query_id, "1")
    self.assertAlmostEqual(rows[1].value, .33333, places=4)
    self.assertAlmostEqual(aggregate(f_default), 0.5417, places=4)

    # rel=1, beta=0.5
    f_half = ir_measures.SetF(rel=1, beta=0.5)
    rows = per_query(f_half)
    self.assertEqual(rows[0].query_id, "0")
    self.assertAlmostEqual(rows[0].value, 0.6923, places=4)
    self.assertEqual(rows[1].query_id, "1")
    self.assertEqual(rows[1].value, 0.3)
    self.assertAlmostEqual(aggregate(f_half), 0.49615, places=4)

    # rel=1, beta=2.0
    f_double = ir_measures.SetF(rel=1, beta=2.0)
    rows = per_query(f_double)
    self.assertEqual(rows[0].query_id, "0")
    self.assertAlmostEqual(rows[0].value, 0.81818, places=4)
    self.assertEqual(rows[1].query_id, "1")
    self.assertEqual(rows[1].value, 0.375)
    self.assertAlmostEqual(aggregate(f_double), 0.59659, places=4)

    # rel=3 is above every judgment in qrels, so every score is 0
    f_strict = ir_measures.SetF(rel=3)
    rows = per_query(f_strict)
    self.assertEqual(rows[0].query_id, "0")
    self.assertEqual(rows[0].value, 0)
    self.assertEqual(rows[1].query_id, "1")
    self.assertEqual(rows[1].value, 0)
    self.assertEqual(aggregate(f_strict), 0)

    # make sure the multiple invocations happen correctly
    requested = [
        ir_measures.SetF(rel=1),
        ir_measures.SetF(rel=1, beta=0.5),
        ir_measures.SetF(rel=1, beta=2.0),
        ir_measures.SetF(rel=3),
    ]
    res = provider.calc_aggregate(requested, qrels, run)
    # Look results up with freshly constructed (equal) measure objects.
    self.assertAlmostEqual(res[ir_measures.SetF(rel=1)], 0.5417, places=4)
    self.assertAlmostEqual(res[ir_measures.SetF(rel=1, beta=0.5)], 0.49615, places=4)
    self.assertAlmostEqual(res[ir_measures.SetF(rel=1, beta=2.0)], 0.59659, places=4)
    self.assertEqual(res[ir_measures.SetF(rel=3)], 0)


def test_IPrec(self):
qrels = list(ir_measures.read_trec_qrels('''
0 0 D0 1
Expand Down
2 changes: 2 additions & 0 deletions test/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def test_parse_trec_measure(self):
'ndcg_cut': [nDCG@5, nDCG@10, nDCG@15, nDCG@20, nDCG@30, nDCG@100, nDCG@200, nDCG@500, nDCG@1000],
'set_P': [SetP],
'set_recall': [SetR],
'set_F': [SetF],
'set_F.1.0,0.5,2.4': [SetF, SetF(beta=0.5), SetF(beta=2.4)],
'official': [P@5, P@10, P@15, P@20, P@30, P@100, P@200, P@500, P@1000, Rprec, Bpref, IPrec@0.0, IPrec@0.1, IPrec@0.2, IPrec@0.3, IPrec@0.4, IPrec@0.5, IPrec@0.6, IPrec@0.7, IPrec@0.8, IPrec@0.9, IPrec@1.0, AP, NumQ, NumRel, NumRelRet, NumRet, RR],
}
for case in cases:
Expand Down

0 comments on commit 3e96cac

Please sign in to comment.