Skip to content

Commit

Permalink
Refactor store_model_to_hdf to
Browse files Browse the repository at this point in the history
store_simulation_state_to_hdf
  • Loading branch information
wkerzendorf committed Dec 10, 2023
1 parent a949f49 commit 56a948d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
18 changes: 10 additions & 8 deletions tardis/io/model/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import h5py


def store_model_to_hdf(model, fname):
def store_simulation_state_to_hdf(simulation_state, fname):
"""
Stores data from SimulationState object into a hdf file.
Expand All @@ -14,14 +14,16 @@ def store_model_to_hdf(model, fname):
filename : str
"""
with h5py.File(fname, "a") as f:
model_group = f.require_group("model")
model_group.clear()
simulation_state_group = f.require_group("simulation_state")
simulation_state_group.clear()

model_dict = simulation_state_to_dict(model)
simulation_state_dict = simulation_state_to_dict(simulation_state)

for key, value in model_dict.items():
for key, value in simulation_state_dict.items():
if key.endswith("_cgs"):
model_group.create_dataset(key, data=value[0])
model_group.create_dataset(key + "_unit", data=value[1])
simulation_state_group.create_dataset(key, data=value[0])
simulation_state_group.create_dataset(
key + "_unit", data=value[1]
)
else:
model_group.create_dataset(key, data=value)
simulation_state_group.create_dataset(key, data=value)
28 changes: 18 additions & 10 deletions tardis/io/tests/test_model_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from astropy import units as u

from tardis.io.configuration.config_reader import Configuration
from tardis.io.model.hdf import store_model_to_hdf
from tardis.io.model.hdf import store_simulation_state_to_hdf
from tardis.io.model.model_reader import (
simulation_state_to_dict,
store_transport_to_hdf,
Expand Down Expand Up @@ -212,29 +212,35 @@ def test_model_to_dict(simulation_verysimple):
def test_store_model_to_hdf(simulation_verysimple, tmp_path):
simulation_state = simulation_verysimple.simulation_state

fname = tmp_path / "model.h5"
fname = tmp_path / "simulation_state.h5"

# Store model object
store_model_to_hdf(simulation_state, fname)
store_simulation_state_to_hdf(simulation_state, fname)

# Check file contents
with h5py.File(fname) as f:
assert np.array_equal(
f["simulation_state/velocity_cgs"], simulation_state.velocity.cgs.value
f["simulation_state/velocity_cgs"],
simulation_state.velocity.cgs.value,
)
assert np.array_equal(
f["simulation_state/abundance"], simulation_state.abundance
)
assert np.array_equal(f["simulation_state/abundance"], simulation_state.abundance)
assert np.array_equal(
f["simulation_state/time_explosion_cgs"],
simulation_state.time_explosion.cgs.value,
)
assert np.array_equal(
f["simulation_state/t_inner_cgs"], simulation_state.t_inner.cgs.value
f["simulation_state/t_inner_cgs"],
simulation_state.t_inner.cgs.value,
)
assert np.array_equal(
f["simulation_state/t_radiative_cgs"], simulation_state.t_radiative.cgs.value
f["simulation_state/t_radiative_cgs"],
simulation_state.t_radiative.cgs.value,
)
assert np.array_equal(
f["simulation_state/dilution_factor"], simulation_state.dilution_factor
f["simulation_state/dilution_factor"],
simulation_state.dilution_factor,
)
assert np.array_equal(
f["simulation_state/v_boundary_inner_cgs"],
Expand All @@ -245,10 +251,12 @@ def test_store_model_to_hdf(simulation_verysimple, tmp_path):
simulation_state.v_boundary_outer.cgs.value,
)
assert np.array_equal(
f["simulation_state/r_inner_cgs"], simulation_state.r_inner.cgs.value
f["simulation_state/r_inner_cgs"],
simulation_state.r_inner.cgs.value,
)
assert np.array_equal(
f["simulation_state/density_cgs"], simulation_state.density.cgs.value
f["simulation_state/density_cgs"],
simulation_state.density.cgs.value,
)


Expand Down

0 comments on commit 56a948d

Please sign in to comment.