Skip to content

Commit f67a9a2

Browse files
committed
fix: support transform axes, more tests
Signed-off-by: Henry Schreiner <[email protected]>
1 parent b92cb04 commit f67a9a2

File tree

4 files changed

+101
-15
lines changed

4 files changed

+101
-15
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ messages_control.disable = [
227227
"too-many-statements",
228228
"too-many-positional-arguments",
229229
"wrong-import-position",
230+
"unused-argument", # Covered by ruff
231+
"unsubscriptable-object", # Wrongly triggered
230232
]
231233

232234
[tool.ruff.lint]

src/boost_histogram/serialization/_axis.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,26 @@ def _axis_to_dict(ax: Any, /) -> dict[str, Any]:
2222
@_axis_to_dict.register(axis.Integer)
2323
def _(ax: axis.Regular | axis.Integer, /) -> dict[str, Any]:
2424
"""Convert a Regular axis to a dictionary."""
25-
data = {
26-
"type": "regular",
27-
"lower": ax.edges[0],
28-
"upper": ax.edges[-1],
29-
"bins": ax.size,
30-
"underflow": ax.traits.underflow,
31-
"overflow": ax.traits.overflow,
32-
"circular": ax.traits.circular,
33-
}
25+
26+
# Special handling if the axis has a transform
27+
if isinstance(ax, axis.Regular) and ax.transform is not None:
28+
data = {
29+
"type": "variable",
30+
"edges": ax.edges,
31+
"underflow": ax.traits.underflow,
32+
"overflow": ax.traits.overflow,
33+
"circular": ax.traits.circular,
34+
}
35+
else:
36+
data = {
37+
"type": "regular",
38+
"lower": ax.edges[0],
39+
"upper": ax.edges[-1],
40+
"bins": ax.size,
41+
"underflow": ax.traits.underflow,
42+
"overflow": ax.traits.overflow,
43+
"circular": ax.traits.circular,
44+
}
3445
if ax.metadata is not None:
3546
data["metadata"] = ax.metadata
3647

tests/test_hdf5.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,25 @@
1515
@pytest.mark.parametrize(
1616
("storage_type", "fill_args", "fill_kwargs"),
1717
[
18-
pytest.param(bh.storage.Weight(), [0.3, 0.3, 0.4, 1.2], {}, id="weight"),
18+
pytest.param(bh.storage.Double(), [0.3, 0.3, 0.4, 1.2], {}, id="double"),
19+
pytest.param(bh.storage.Int64(), [0.3, 0.3, 0.4, 1.2], {}, id="int64"),
1920
pytest.param(
20-
bh.storage.WeightedMean(),
21-
[0.3, 0.3, 0.4, 1.2, 1.6],
22-
{"sample": [1, 2, 3, 4, 4], "weight": [1, 1, 1, 1, 2]},
23-
id="weighted_mean",
21+
bh.storage.AtomicInt64(), [0.3, 0.3, 0.4, 1.2], {}, id="atomicint"
2422
),
23+
pytest.param(bh.storage.Unlimited(), [0.3, 0.3, 0.4, 1.2], {}, id="unlimited"),
24+
pytest.param(bh.storage.Weight(), [0.3, 0.3, 0.4, 1.2], {}, id="weight"),
2525
pytest.param(
2626
bh.storage.Mean(),
2727
[0.3, 0.3, 0.4, 1.2, 1.6],
2828
{"sample": [1, 2, 3, 4, 4]},
2929
id="mean",
3030
),
31+
pytest.param(
32+
bh.storage.WeightedMean(),
33+
[0.3, 0.3, 0.4, 1.2, 1.6],
34+
{"sample": [1, 2, 3, 4, 4], "weight": [1, 1, 1, 1, 2]},
35+
id="weighted_mean",
36+
),
3137
],
3238
)
3339
def test_hdf5_storage(
@@ -51,7 +57,14 @@ def test_hdf5_storage(
5157

5258
# checking types of the reconstructed axes
5359
assert type(actual_hist.axes[0]) is type(re_constructed_hist.axes[0])
54-
assert actual_hist.storage_type == re_constructed_hist.storage_type
60+
61+
if isinstance(storage_type, bh.storage.Unlimited):
62+
actual_hist_storage = bh.storage.Double()
63+
elif isinstance(storage_type, bh.storage.AtomicInt64):
64+
actual_hist_storage = bh.storage.Int64()
65+
else:
66+
actual_hist_storage = storage_type
67+
assert isinstance(actual_hist_storage, re_constructed_hist.storage_type)
5568

5669
# checking values of the essential inputs of the axes
5770
assert actual_hist.axes[0].traits == re_constructed_hist.axes[0].traits
@@ -76,3 +89,43 @@ def test_hdf5_storage(
7689
assert actual_hist.counts() == pytest.approx(
7790
re_constructed_hist.counts(), abs=1e-4, rel=1e-9
7891
)
92+
93+
94+
def test_hdf5_2d(tmp_path: Path) -> None:
95+
h = bh.Histogram(bh.axis.Integer(0, 4), bh.axis.StrCategory(["a", "b", "c"]))
96+
h.fill([0, 1, 1, 1], ["a", "b", "b", "c"])
97+
98+
filepath = tmp_path / "hist.h5"
99+
with h5py.File(filepath, "x") as f:
100+
grp = f.create_group("test_hist")
101+
s.write_hdf5_schema(grp, h)
102+
103+
with h5py.File(filepath) as f:
104+
re_constructed_hist = s.read_hdf5_schema(f["test_hist"])
105+
106+
actual_hist = h.copy()
107+
108+
assert isinstance(re_constructed_hist.axes[0], bh.axis.Regular)
109+
assert type(actual_hist.axes[1]) is type(re_constructed_hist.axes[1])
110+
111+
assert (
112+
actual_hist.axes[0].traits.underflow
113+
== re_constructed_hist.axes[0].traits.underflow
114+
)
115+
assert (
116+
actual_hist.axes[0].traits.overflow
117+
== re_constructed_hist.axes[0].traits.overflow
118+
)
119+
assert (
120+
actual_hist.axes[0].traits.circular
121+
== re_constructed_hist.axes[0].traits.circular
122+
)
123+
assert actual_hist.axes[1].traits == re_constructed_hist.axes[1].traits
124+
125+
assert np.asarray(actual_hist.axes[0].edges) == pytest.approx(
126+
np.asarray(re_constructed_hist.axes[0].edges)
127+
)
128+
assert list(actual_hist.axes[1]) == list(re_constructed_hist.axes[1])
129+
130+
assert h.values() == pytest.approx(re_constructed_hist.values())
131+
assert h.values(flow=True) == pytest.approx(re_constructed_hist.values(flow=True))

tests/test_serialization_generic.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,26 @@ def test_weighted_mean_to_dict() -> None:
9999
assert data["storage"]["data"]["variances"] == pytest.approx(np.zeros(4))
100100

101101

102+
def test_transform_log_axis_to_dict() -> None:
103+
h = bh.Histogram(bh.axis.Regular(10, 1, 10, transform=bh.axis.transform.log))
104+
data = generic.to_dict(h)
105+
106+
assert data["axes"][0]["type"] == "variable"
107+
assert data["axes"][0]["edges"] == pytest.approx(
108+
np.exp(np.linspace(0, np.log(10), 11))
109+
)
110+
111+
112+
def test_transform_sqrt_axis_to_dict() -> None:
113+
h = bh.Histogram(bh.axis.Regular(10, 0, 10, transform=bh.axis.transform.sqrt))
114+
data = generic.to_dict(h)
115+
116+
assert data["axes"][0]["type"] == "variable"
117+
assert data["axes"][0]["edges"] == pytest.approx(
118+
(np.linspace(0, np.sqrt(10), 11)) ** 2
119+
)
120+
121+
102122
@pytest.mark.parametrize(
103123
"storage_type",
104124
[

0 commit comments

Comments
 (0)