Skip to content

Commit b92cb04

Browse files
committed
fix: some cleanup and a mistake fixed that Copilot noticed
Signed-off-by: Henry Schreiner <[email protected]>
1 parent 17d8137 commit b92cb04

File tree

4 files changed

+71
-56
lines changed

4 files changed

+71
-56
lines changed

src/boost_histogram/serialization/_axis.py

Lines changed: 37 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from __future__ import annotations
22

33
import functools
4-
from collections.abc import Generator
54
from typing import Any
65

76
from .. import axis
87

9-
__all__ = ["_axes_from_dict", "_axis_to_dict"]
8+
__all__ = ["_axis_from_dict", "_axis_to_dict"]
109

1110

1211
def __dir__() -> list[str]:
@@ -94,43 +93,39 @@ def _(ax: axis.Boolean, /) -> dict[str, Any]:
9493
return data
9594

9695

97-
def _axes_from_dict(
98-
data_list: list[dict[str, Any]], /
99-
) -> Generator[axis.Axis, None, None]:
100-
for data in data_list:
101-
hist_type = data["type"]
102-
opts = {"metadata": data["metadata"]} if "metadata" in data else {}
103-
if hist_type == "regular":
104-
yield axis.Regular(
105-
data["bins"],
106-
data["lower"],
107-
data["upper"],
108-
underflow=data["underflow"],
109-
overflow=data["overflow"],
110-
circular=data["circular"],
111-
**opts,
112-
)
113-
elif hist_type == "variable":
114-
yield axis.Variable(
115-
data["edges"],
116-
underflow=data["underflow"],
117-
overflow=data["overflow"],
118-
circular=data["circular"],
119-
**opts,
120-
)
121-
elif hist_type == "category_int":
122-
yield axis.IntCategory(
123-
data["categories"],
124-
overflow=data["flow"],
125-
**opts,
126-
)
127-
elif hist_type == "category_str":
128-
yield axis.StrCategory(
129-
data["categories"],
130-
overflow=data["flow"],
131-
**opts,
132-
)
133-
elif hist_type == "boolean":
134-
yield axis.Boolean(**opts)
135-
else:
136-
raise TypeError(f"Unsupported axis type: {hist_type}")
96+
def _axis_from_dict(data: dict[str, Any], /) -> axis.Axis:
97+
hist_type = data["type"]
98+
if hist_type == "regular":
99+
return axis.Regular(
100+
data["bins"],
101+
data["lower"],
102+
data["upper"],
103+
underflow=data["underflow"],
104+
overflow=data["overflow"],
105+
circular=data["circular"],
106+
metadata=data.get("metadata"),
107+
)
108+
if hist_type == "variable":
109+
return axis.Variable(
110+
data["edges"],
111+
underflow=data["underflow"],
112+
overflow=data["overflow"],
113+
circular=data["circular"],
114+
metadata=data.get("metadata"),
115+
)
116+
if hist_type == "category_int":
117+
return axis.IntCategory(
118+
data["categories"],
119+
overflow=data["flow"],
120+
metadata=data.get("metadata"),
121+
)
122+
if hist_type == "category_str":
123+
return axis.StrCategory(
124+
data["categories"],
125+
overflow=data["flow"],
126+
metadata=data.get("metadata"),
127+
)
128+
if hist_type == "boolean":
129+
return axis.Boolean(metadata=data.get("metadata"))
130+
131+
raise TypeError(f"Unsupported axis type: {hist_type}")

src/boost_histogram/serialization/generic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any
44

55
from .. import Histogram
6-
from ._axis import _axes_from_dict, _axis_to_dict
6+
from ._axis import _axis_from_dict, _axis_to_dict
77
from ._storage import _data_from_dict, _storage_from_dict, _storage_to_dict
88

99
__all__ = ["from_dict", "to_dict"]
@@ -31,7 +31,7 @@ def from_dict(data: dict[str, Any], /) -> Histogram:
3131
"""Convert a dictionary to an Histogram."""
3232

3333
h = Histogram(
34-
*_axes_from_dict(data["axes"]),
34+
*(_axis_from_dict(ax) for ax in data["axes"]),
3535
storage=_storage_from_dict(data["storage"]),
3636
metadata=data.get("metadata"),
3737
)

src/boost_histogram/serialization/hdf5.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ def __dir__() -> list[str]:
1616

1717

1818
def write_hdf5_schema(grp: h5py.Group, /, histogram: Histogram) -> None:
19+
"""
20+
Write a histogram to an HDF5 group.
21+
"""
1922
hist_dict = to_dict(histogram)
2023

2124
# All referenced objects will be stored inside of /{name}/ref_axes
@@ -43,9 +46,9 @@ def write_hdf5_schema(grp: h5py.Group, /, histogram: Histogram) -> None:
4346
for key, value in ax_info.items():
4447
ax_group.attrs[key] = value
4548
if ax_metadata is not None:
46-
ax_metadata = ax_group.create_group("metadata")
47-
for k, v in value.items():
48-
ax_metadata.attrs[k] = v
49+
ax_metadata_grp = ax_group.create_group("metadata")
50+
for k, v in ax_metadata.items():
51+
ax_metadata_grp.attrs[k] = v
4952
if ax_edges is not None:
5053
ax_group.create_dataset("edges", shape=ax_edges.shape, data=ax_edges)
5154
if ax_cats is not None:
@@ -68,7 +71,9 @@ def write_hdf5_schema(grp: h5py.Group, /, histogram: Histogram) -> None:
6871

6972

7073
def read_hdf5_schema(grp: h5py.Group, /) -> Histogram:
71-
axes: list[dict[str, Any]] = []
74+
"""
75+
Read a histogram from an HDF5 group.
76+
"""
7277
axes_grp = grp["axes"]
7378
axes_ref = grp["ref_axes"]
7479
assert isinstance(axes_ref, h5py.Group)

tests/test_serialization_generic.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
[
1313
pytest.param(bh.storage.AtomicInt64(), "int", id="atomic_int"),
1414
pytest.param(bh.storage.Int64(), "int", id="int"),
15-
pytest.param(bh.storage.Unlimited(), "double", id="unlimited"), # This always renders as double
15+
pytest.param(
16+
bh.storage.Unlimited(), "double", id="unlimited"
17+
), # This always renders as double
1618
pytest.param(bh.storage.Double(), "double", id="double"),
17-
])
19+
],
20+
)
1821
def test_simple_to_dict(storage_type: bh.storage.Storage, expected_type: str) -> None:
1922
h = bh.Histogram(
2023
bh.axis.Regular(10, 0, 1),
@@ -33,6 +36,7 @@ def test_simple_to_dict(storage_type: bh.storage.Storage, expected_type: str) ->
3336
assert data["storage"]["type"] == expected_type
3437
assert data["storage"]["data"] == pytest.approx(np.zeros(12))
3538

39+
3640
def test_weighed_to_dict() -> None:
3741
h = bh.Histogram(
3842
bh.axis.Integer(3, 15),
@@ -51,11 +55,12 @@ def test_weighed_to_dict() -> None:
5155
assert data["storage"]["data"]["values"] == pytest.approx(np.zeros(14))
5256
assert data["storage"]["data"]["variances"] == pytest.approx(np.zeros(14))
5357

58+
5459
def test_mean_to_dict() -> None:
5560
h = bh.Histogram(
5661
bh.axis.StrCategory(["one", "two", "three"]),
5762
storage=bh.storage.Mean(),
58-
metadata = {"name": "hi"},
63+
metadata={"name": "hi"},
5964
)
6065
data = generic.to_dict(h)
6166

@@ -68,6 +73,7 @@ def test_mean_to_dict() -> None:
6873
assert data["storage"]["data"]["values"] == pytest.approx(np.zeros(4))
6974
assert data["storage"]["data"]["variances"] == pytest.approx(np.zeros(4))
7075

76+
7177
def test_weighted_mean_to_dict() -> None:
7278
h = bh.Histogram(
7379
bh.axis.IntCategory([1, 2, 3]),
@@ -81,9 +87,15 @@ def test_weighted_mean_to_dict() -> None:
8187
assert data["axes"][0]["categories"] == pytest.approx([1, 2, 3])
8288
assert data["axes"][0]["flow"]
8389
assert data["storage"]["type"] == "weighted_mean"
84-
assert data["storage"]["data"]["sum_of_weights"] == pytest.approx(np.array([20, 40, 60, 10]))
85-
assert data["storage"]["data"]["sum_of_weights_squared"] == pytest.approx(np.array([200, 800, 1800, 50]))
86-
assert data["storage"]["data"]["values"] == pytest.approx(np.array([100, 200, 300, 1]))
90+
assert data["storage"]["data"]["sum_of_weights"] == pytest.approx(
91+
np.array([20, 40, 60, 10])
92+
)
93+
assert data["storage"]["data"]["sum_of_weights_squared"] == pytest.approx(
94+
np.array([200, 800, 1800, 50])
95+
)
96+
assert data["storage"]["data"]["values"] == pytest.approx(
97+
np.array([100, 200, 300, 1])
98+
)
8799
assert data["storage"]["data"]["variances"] == pytest.approx(np.zeros(4))
88100

89101

@@ -94,7 +106,8 @@ def test_weighted_mean_to_dict() -> None:
94106
pytest.param(bh.storage.Int64(), id="int"),
95107
pytest.param(bh.storage.Double(), id="double"),
96108
pytest.param(bh.storage.Unlimited(), id="unlimited"),
97-
])
109+
],
110+
)
98111
def test_round_trip_simple(storage_type: bh.storage.Storage) -> None:
99112
h = bh.Histogram(
100113
bh.axis.Regular(10, 0, 10),
@@ -126,6 +139,7 @@ def test_round_trip_weighted() -> None:
126139
assert pytest.approx(np.array(h.axes[0])) == np.array(h2.axes[0])
127140
assert np.asarray(h) == pytest.approx(h2)
128141

142+
129143
def test_round_trip_mean() -> None:
130144
h = bh.Histogram(
131145
bh.axis.StrCategory(["1", "2", "3"]),
@@ -139,6 +153,7 @@ def test_round_trip_mean() -> None:
139153
assert pytest.approx(np.array(h.axes[0])) == np.array(h2.axes[0])
140154
assert np.asarray(h) == pytest.approx(h2)
141155

156+
142157
def test_round_trip_weighted_mean() -> None:
143158
h = bh.Histogram(
144159
bh.axis.IntCategory([1, 2, 3]),
@@ -150,4 +165,4 @@ def test_round_trip_weighted_mean() -> None:
150165
h2 = generic.from_dict(data)
151166

152167
assert pytest.approx(np.array(h.axes[0])) == np.array(h2.axes[0])
153-
assert np.asarray(h) == pytest.approx(h2)
168+
assert np.asarray(h) == pytest.approx(h2)

0 commit comments

Comments
 (0)