Skip to content

Commit

Permalink
Merge pull request #1274 from scikit-hep/ikrommyd/error-if-name-exists-in-analysis-tools
Browse files (browse the repository at this point in the history)

fix: error if `name` already exists in `analysis_tools`'s `Weights` and `PackedSelection`
  • Loading branch information
lgray authored Feb 14, 2025
2 parents 3d051d9 + c1cf851 commit 2668081
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
11 changes: 11 additions & 0 deletions src/coffea/analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, size, storeIndividual=False):
self._modifiers = {}
self._weightStats = {}
self._storeIndividual = storeIndividual
self._names = []

@property
def weightStatistics(self):
Expand All @@ -127,6 +128,7 @@ def __add_eager(self, name, weight, weightUp, weightDown, shift):
weight.max(),
weight.size,
)
self._names.append(name)

def __add_delayed(self, name, weight, weightUp, weightDown, shift):
"""Add a new weight with delayed calculation"""
Expand All @@ -148,6 +150,7 @@ def __add_delayed(self, name, weight, weightUp, weightDown, shift):
"minw": dask_awkward.min(weight),
"maxw": dask_awkward.max(weight),
}
self._names.append(name)

def add(self, name, weight, weightUp=None, weightDown=None, shift=False):
"""Add a new weight
Expand All @@ -173,6 +176,8 @@ def add(self, name, weight, weightUp=None, weightDown=None, shift=False):
.. note:: ``weightUp`` and ``weightDown`` are assumed to be rvalue-like and may be modified in-place by this function
"""
if name in self._names:
raise ValueError(f"Weight '{name}' already exists")
if name.endswith("Up") or name.endswith("Down"):
raise ValueError(
"Avoid using 'Up' and 'Down' in weight names, instead pass appropriate shifts to add() call"
Expand Down Expand Up @@ -223,6 +228,7 @@ def __add_multivariation_eager(
weight.max(),
weight.size,
)
self._names.append(name)

def __add_multivariation_delayed(
self, name, weight, modifierNames, weightsUp, weightsDown, shift=False
Expand Down Expand Up @@ -258,6 +264,7 @@ def __add_multivariation_delayed(
"minw": dask_awkward.min(weight),
"maxw": dask_awkward.max(weight),
}
self._names.append(name)

def add_multivariation(
self, name, weight, modifierNames, weightsUp, weightsDown, shift=False
Expand Down Expand Up @@ -287,6 +294,8 @@ def add_multivariation(
.. note:: ``weightUp`` and ``weightDown`` are assumed to be rvalue-like and may be modified in-place by this function
"""
if name in self._names:
raise ValueError(f"Weight '{name}' already exists")
if name.endswith("Up") or name.endswith("Down"):
raise ValueError(
"Avoid using 'Up' and 'Down' in weight names, instead pass appropriate shifts to add() call"
Expand Down Expand Up @@ -1234,6 +1243,8 @@ def add(self, name, selection, fill_value=False):
fill_value : bool, optional
All masked entries will be filled as specified (default: ``False``)
"""
if name in self._names:
raise ValueError(f"Selection '{name}' already exists")
if isinstance(selection, dask.array.Array):
raise ValueError(
"Dask arrays are not supported, please convert them to dask_awkward.Array by using dask_awkward.from_dask_array()"
Expand Down
43 changes: 41 additions & 2 deletions tests/test_analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def test_weights():
shift=True,
)

with pytest.raises(ValueError, match="Weight 'test' already exists"):
weight.add("test", scale_central, weightUp=scale_up, weightDown=scale_down)

var_names = weight.variations
expected_names = ["testShiftUp", "testShiftDown", "testUp", "testDown"]
for name in expected_names:
Expand Down Expand Up @@ -105,6 +108,9 @@ def test_weights_dak(optimization_enabled):
shift=True,
)

with pytest.raises(ValueError, match="Weight 'test' already exists"):
weight.add("test", scale_central, weightUp=scale_up, weightDown=scale_down)

var_names = weight.variations
expected_names = ["testShiftUp", "testShiftDown", "testUp", "testDown"]
for name in expected_names:
Expand Down Expand Up @@ -153,6 +159,15 @@ def test_weights_multivariation():
weightsDown=[scale_down, scale_down_2],
)

with pytest.raises(ValueError, match="Weight 'test' already exists"):
weight.add_multivariation(
"test",
scale_central,
modifierNames=["A", "B"],
weightsUp=[scale_up, scale_up_2],
weightsDown=[scale_down, scale_down_2],
)

var_names = weight.variations
expected_names = ["test_AUp", "test_ADown", "test_BUp", "test_BDown"]
for name in expected_names:
Expand Down Expand Up @@ -211,6 +226,15 @@ def test_weights_multivariation_dak(optimization_enabled):
weightsDown=[scale_down, scale_down_2],
)

with pytest.raises(ValueError, match="Weight 'test' already exists"):
weight.add_multivariation(
"test",
scale_central,
modifierNames=["A", "B"],
weightsUp=[scale_up, scale_up_2],
weightsDown=[scale_down, scale_down_2],
)

var_names = weight.variations
expected_names = ["test_AUp", "test_ADown", "test_BUp", "test_BDown"]
for name in expected_names:
Expand Down Expand Up @@ -253,6 +277,11 @@ def test_weights_partial():
weights.add("w1", w1)
weights.add("w2", w2)

with pytest.raises(ValueError, match="Weight 'w1' already exists"):
weights.add("w1", w1)
with pytest.raises(ValueError, match="Weight 'w2' already exists"):
weights.add("w2", w2)

test_exclude_none = weights.weight()
assert np.all(np.abs(test_exclude_none - w1 * w2) < 1e-6)

Expand Down Expand Up @@ -321,6 +350,11 @@ def test_weights_partial_dak(optimization_enabled):
weights.add("w1", w1)
weights.add("w2", w2)

with pytest.raises(ValueError, match="Weight 'w1' already exists"):
weights.add("w1", w1)
with pytest.raises(ValueError, match="Weight 'w2' already exists"):
weights.add("w2", w2)

test_exclude_none = weights.weight()
assert np.all(np.abs(test_exclude_none - w1 * w2).compute() < 1e-6)

Expand Down Expand Up @@ -397,6 +431,11 @@ def test_packed_selection_basic(dtype):
sel.add("fizz", fizz)
sel.add("buzz", buzz)

with pytest.raises(ValueError, match="Selection 'fizz' already exists"):
sel.add("fizz", fizz)
with pytest.raises(ValueError, match="Selection 'buzz' already exists"):
sel.add("buzz", buzz)

assert np.all(
sel.all()
== np.array(
Expand Down Expand Up @@ -449,7 +488,7 @@ def test_packed_selection_basic(dtype):
with pytest.raises(RuntimeError):
overpack = PackedSelection(dtype=dtype)
for i in range(65):
overpack.add("sel_%d", all_true)
overpack.add(f"sel_{i}", all_true)

with pytest.raises(
ValueError,
Expand Down Expand Up @@ -787,7 +826,7 @@ def test_packed_selection_basic_dak(optimization_enabled, dtype):
with pytest.raises(RuntimeError):
overpack = PackedSelection(dtype=dtype)
for i in range(65):
overpack.add("sel_%d", all_true)
overpack.add(f"sel_{i}", all_true)

with pytest.raises(
ValueError,
Expand Down

0 comments on commit 2668081

Please sign in to comment.