Skip to content

Commit 8833767

Browse files
authored
Merge pull request #107 from NREL/feature/serde-api-features
Feature/serde api features
2 parents f9a1d7e + 47df893 commit 8833767

28 files changed

+366
-84
lines changed

.github/workflows/py-tests.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
strategy:
2727
fail-fast: false
2828
matrix:
29-
python-version: ['3.9', '3.10', '3.11']
29+
python-version: ['3.10', '3.11']
3030

3131
env:
3232
PYTHON: ${{ matrix.python-version }}

.github/workflows/wheels.yaml

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ jobs:
2121
- macos
2222
- windows
2323
python-version:
24-
- "9"
2524
- "10"
2625
- "11"
2726
include:
@@ -36,7 +35,7 @@ jobs:
3635
- name: set up python
3736
uses: actions/setup-python@v4
3837
with:
39-
python-version: "3.9"
38+
python-version: "3.11"
4039

4140
- name: set up rust
4241
if: matrix.os != 'ubuntu'

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ authors = [
2424
description = "Tool for modeling and optimization of advanced locomotive powertrains for freight rail decarbonization."
2525
readme = "README.md"
2626
license = { file = "LICENSE.md" }
27-
requires-python = ">=3.9, <3.12"
27+
requires-python = ">=3.10, <3.12"
2828
classifiers = [
2929
"Programming Language :: Python :: 3",
3030
"License :: OSI Approved :: BSD License",
@@ -47,6 +47,7 @@ dependencies = [
4747
"pyarrow",
4848
"requests",
4949
"PyYAML==6.0.2",
50+
"msgpack==1.1.0",
5051
]
5152

5253
[project.urls]

python/altrios/__init__.py

+62-13
Original file line numberDiff line numberDiff line change
@@ -119,27 +119,76 @@ def history_path_list(self, element_as_list:bool=False) -> List[str]:
119119
item for item in self.variable_path_list(
120120
element_as_list=element_as_list) if "history" in item_str(item)
121121
]
122-
return history_path_list
123-
124-
def to_pydict(self) -> Dict:
122+
return history_path_list
123+
124+
# TODO connect to crate features
125+
data_formats = [
126+
'yaml',
127+
'msg_pack',
128+
# 'toml',
129+
'json',
130+
]
131+
132+
def to_pydict(self, data_fmt: str = "msg_pack", flatten: bool = False) -> Dict:
125133
"""
126134
Returns self converted to pure python dictionary with no nested Rust objects
135+
# Arguments
136+
- `flatten`: if True, returns dict without any hierarchy
137+
- `data_fmt`: data format for intermediate conversion step
127138
"""
128-
from yaml import load
129-
try:
130-
from yaml import CLoader as Loader
131-
except ImportError:
132-
from yaml import Loader
133-
pydict = load(self.to_yaml(), Loader = Loader)
134-
return pydict
139+
data_fmt = data_fmt.lower()
140+
assert data_fmt in data_formats, f"`data_fmt` must be one of {data_formats}"
141+
match data_fmt:
142+
case "msg_pack":
143+
import msgpack
144+
pydict = msgpack.loads(self.to_msg_pack())
145+
case "yaml":
146+
from yaml import load
147+
try:
148+
from yaml import CLoader as Loader
149+
except ImportError:
150+
from yaml import Loader
151+
pydict = load(self.to_yaml(), Loader=Loader)
152+
case "json":
153+
from json import loads
154+
pydict = loads(self.to_json())
155+
156+
if not flatten:
157+
return pydict
158+
else:
159+
return next(iter(pd.json_normalize(pydict, sep=".").to_dict(orient='records')))
135160

136161
@classmethod
137-
def from_pydict(cls, pydict: Dict) -> Self:
162+
def from_pydict(cls, pydict: Dict, data_fmt: str = "msg_pack", skip_init: bool = True) -> Self:
138163
"""
139164
Instantiates Self from pure python dictionary
165+
# Arguments
166+
- `pydict`: dictionary to be converted to ALTRIOS object
167+
- `data_fmt`: data format for intermediate conversion step
168+
- `skip_init`: passed to `SerdeAPI` methods to control whether initialization
169+
is skipped
140170
"""
141-
import yaml
142-
return cls.from_yaml(yaml.dump(pydict),skip_init=False)
171+
data_fmt = data_fmt.lower()
172+
assert data_fmt in data_formats, f"`data_fmt` must be one of {data_formats}"
173+
match data_fmt.lower():
174+
case "yaml":
175+
import yaml
176+
obj = cls.from_yaml(yaml.dump(pydict), skip_init=skip_init)
177+
case "msg_pack":
178+
import msgpack
179+
try:
180+
obj = cls.from_msg_pack(
181+
msgpack.packb(pydict), skip_init=skip_init)
182+
except Exception as err:
183+
print(
184+
f"{err}\nFalling back to YAML.")
185+
obj = cls.from_pydict(
186+
pydict, data_fmt="yaml", skip_init=skip_init)
187+
case "json":
188+
from json import dumps
189+
obj = cls.from_json(dumps(pydict), skip_init=skip_init)
190+
191+
return obj
143192

144193
def to_dataframe(self, pandas:bool=False) -> [pd.DataFrame, pl.DataFrame, pl.LazyFrame]:
145194
"""

python/altrios/altrios_pyo3.pyi

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@ class SerdeAPI(object):
1616
@classmethod
1717
def from_yaml(cls) -> Self: ...
1818
@classmethod
19-
def from_file(cls) -> Self: ...
19+
def from_file(cls, skip_init=False) -> Self: ...
2020
def to_file(self): ...
2121
def to_bincode(self) -> bytes: ...
2222
def to_json(self) -> str: ...
2323
def to_yaml(self) -> str: ...
24+
def to_pydict(self, data_fmt: str = "msg_pack", flatten: bool = False) -> Dict: ...
25+
@classmethod
26+
def from_pydict(cls, pydict: Dict, data_fmt: str = "msg_pack") -> Self:
2427

2528

2629
class Consist(SerdeAPI):

python/altrios/demos/sim_manager_demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
t0_import = time.perf_counter()
2424
t0_total = time.perf_counter()
2525

26-
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file)
26+
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False)
2727
for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')]
2828

2929
location_map = alt.import_locations(alt.resources_root() / "networks/default_locations.csv")

python/altrios/demos/version_migration_demo.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ def migrate_network() -> Tuple[alt.Network, alt.Network]:
1313
old_network_path = alt.resources_root() / "networks/Taconite_v0.1.6.yaml"
1414
new_network_path = alt.resources_root() / "networks/Taconite.yaml"
1515

16-
network_from_old = alt.Network.from_file(old_network_path)
17-
network_from_new = alt.Network.from_file(new_network_path)
16+
network_from_old = alt.Network.from_file(old_network_path, skip_init=False)
17+
network_from_new = alt.Network.from_file(new_network_path, skip_init=False)
1818

1919
# `network_from_old` could be used to overwrite the file in the new format with
2020
# ```

python/altrios/rollout.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ def simulate_prescribed_rollout(
5959
else:
6060
demand_paths.append(demand_file)
6161

62-
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file)
62+
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False)
6363
for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')]
6464

6565
location_map = alt.import_locations(
6666
str(alt.resources_root() / "networks/default_locations.csv")
6767
)
68-
network = alt.Network.from_file(network_filename_path)
68+
network = alt.Network.from_file(network_filename_path, skip_init=False)
6969
sim_days = defaults.SIMULATION_DAYS
7070
scenarios = []
7171
for idx, scenario_year in enumerate(years):

python/altrios/tests/test_serde.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import time
2+
import altrios as alt
3+
4+
SAVE_INTERVAL = 100
5+
def get_solved_speed_limit_train_sim():
6+
# Build the train config
7+
rail_vehicle_loaded = alt.RailVehicle.from_file(
8+
alt.resources_root() / "rolling_stock/Manifest_Loaded.yaml")
9+
rail_vehicle_empty = alt.RailVehicle.from_file(
10+
alt.resources_root() / "rolling_stock/Manifest_Empty.yaml")
11+
12+
# https://docs.rs/altrios-core/latest/altrios_core/train/struct.TrainConfig.html
13+
train_config = alt.TrainConfig(
14+
rail_vehicles=[rail_vehicle_loaded, rail_vehicle_empty],
15+
n_cars_by_type={
16+
"Manifest_Loaded": 50,
17+
"Manifest_Empty": 50,
18+
},
19+
train_length_meters=None,
20+
train_mass_kilograms=None,
21+
)
22+
23+
# Build the locomotive consist model
24+
# instantiate battery model
25+
# https://docs.rs/altrios-core/latest/altrios_core/consist/locomotive/powertrain/reversible_energy_storage/struct.ReversibleEnergyStorage.html#
26+
res = alt.ReversibleEnergyStorage.from_file(
27+
alt.resources_root() / "powertrains/reversible_energy_storages/Kokam_NMC_75Ah_flx_drive.yaml"
28+
)
29+
30+
edrv = alt.ElectricDrivetrain(
31+
pwr_out_frac_interp=[0., 1.],
32+
eta_interp=[0.98, 0.98],
33+
pwr_out_max_watts=5e9,
34+
save_interval=SAVE_INTERVAL,
35+
)
36+
37+
bel: alt.Locomotive = alt.Locomotive.build_battery_electric_loco(
38+
reversible_energy_storage=res,
39+
drivetrain=edrv,
40+
loco_params=alt.LocoParams.from_dict(dict(
41+
pwr_aux_offset_watts=8.55e3,
42+
pwr_aux_traction_coeff_ratio=540.e-6,
43+
force_max_newtons=667.2e3,
44+
)))
45+
46+
# construct a vector of one BEL and several conventional locomotives
47+
loco_vec = [bel.clone()] + [alt.Locomotive.default()] * 7
48+
# instantiate consist
49+
loco_con = alt.Consist(
50+
loco_vec
51+
)
52+
53+
# Instantiate the intermediate `TrainSimBuilder`
54+
tsb = alt.TrainSimBuilder(
55+
train_id="0",
56+
origin_id="A",
57+
destination_id="B",
58+
train_config=train_config,
59+
loco_con=loco_con,
60+
)
61+
62+
# Load the network and construct the timed link path through the network.
63+
network = alt.Network.from_file(
64+
alt.resources_root() / 'networks/simple_corridor_network.yaml')
65+
66+
location_map = alt.import_locations(
67+
alt.resources_root() / "networks/simple_corridor_locations.csv")
68+
train_sim: alt.SetSpeedTrainSim = tsb.make_speed_limit_train_sim(
69+
location_map=location_map,
70+
save_interval=1,
71+
)
72+
train_sim.set_save_interval(SAVE_INTERVAL)
73+
est_time_net, _consist = alt.make_est_times(train_sim, network)
74+
75+
timed_link_path = alt.run_dispatch(
76+
network,
77+
alt.SpeedLimitTrainSimVec([train_sim]),
78+
[est_time_net],
79+
False,
80+
False,
81+
)[0]
82+
83+
train_sim.walk_timed_path(
84+
network=network,
85+
timed_path=timed_link_path,
86+
)
87+
assert len(train_sim.history) > 1
88+
89+
return train_sim
90+
91+
92+
def test_pydict():
93+
ts = get_solved_speed_limit_train_sim()
94+
95+
t0 = time.perf_counter_ns()
96+
ts_dict_msg = ts.to_pydict(flatten=False, data_fmt="msg_pack")
97+
ts_msg = alt.SpeedLimitTrainSim.from_pydict(
98+
ts_dict_msg, data_fmt="msg_pack")
99+
t1 = time.perf_counter_ns()
100+
t_msg = t1 - t0
101+
print(f"\nElapsed time for MessagePack: {t_msg:.3e} ns ")
102+
103+
t0 = time.perf_counter_ns()
104+
ts_dict_yaml = ts.to_pydict(flatten=False, data_fmt="yaml")
105+
ts_yaml = alt.SpeedLimitTrainSim.from_pydict(ts_dict_yaml, data_fmt="yaml")
106+
t1 = time.perf_counter_ns()
107+
t_yaml = t1 - t0
108+
print(f"Elapsed time for YAML: {t_yaml:.3e} ns ")
109+
print(f"YAML time per MessagePack time: {(t_yaml / t_msg):.3e} ")
110+
111+
t0 = time.perf_counter_ns()
112+
ts_dict_json = ts.to_pydict(flatten=False, data_fmt="json")
113+
_ts_json = alt.SpeedLimitTrainSim.from_pydict(
114+
ts_dict_json, data_fmt="json")
115+
t1 = time.perf_counter_ns()
116+
t_json = t1 - t0
117+
print(f"Elapsed time for json: {t_json:.3e} ns ")
118+
print(f"JSON time per MessagePack time: {(t_json / t_msg):.3e} ")
119+
120+
# `to_pydict` is necessary because of some funkiness with direct equality comparison
121+
assert ts_msg.to_pydict() == ts.to_pydict()
122+
assert ts_yaml.to_pydict() == ts.to_pydict()
123+
124+
if __name__ == "__main__":
125+
test_pydict()

python/altrios/train_planner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,7 @@ def run_train_planner(
11291129

11301130
if __name__ == "__main__":
11311131

1132-
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file)
1132+
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False)
11331133
for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')]
11341134

11351135
location_map = alt.import_locations(

rust/Cargo.lock

+23
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/altrios-core/Cargo.toml

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ rust-version = { workspace = true }
1313
[dependencies]
1414
csv = "1.1.6"
1515
serde = { version = "1.0.136", features = ["derive"] }
16+
rmp-serde = { version = "1.3.0", optional = true }
1617
serde_yaml = "0.8.23"
1718
serde_json = "1.0"
1819
uom = { workspace = true, features = ["use_serde"] }
@@ -56,9 +57,13 @@ tempfile = "3.10.1"
5657
derive_more = { version = "1.0.0", features = ["from_str", "from", "is_variant", "try_into"] }
5758

5859
[features]
59-
default = []
60+
default = ["serde-default"]
61+
## Enables several text file formats for serialization and deserialization
62+
serde-default = ["msgpack"]
6063
## Exposes ALTRIOS structs, methods, and functions to Python.
6164
pyo3 = ["dep:pyo3"]
65+
## Enables message pack serialization and deserialization via `rmp-serde`
66+
msgpack = ["dep:rmp-serde"]
6267

6368
[lints.rust]
6469
# `'cfg(debug_advance_rewind)'` is expected for debugging in `advance_rewind.rs`

0 commit comments

Comments
 (0)