diff --git a/.github/workflows/py-tests.yaml b/.github/workflows/py-tests.yaml index 0e46f531..16dcf923 100644 --- a/.github/workflows/py-tests.yaml +++ b/.github/workflows/py-tests.yaml @@ -26,7 +26,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11'] + python-version: ['3.10', '3.11'] env: PYTHON: ${{ matrix.python-version }} diff --git a/.github/workflows/wheels.yaml b/.github/workflows/wheels.yaml index 2e23a889..03d2dfb1 100644 --- a/.github/workflows/wheels.yaml +++ b/.github/workflows/wheels.yaml @@ -21,7 +21,6 @@ jobs: - macos - windows python-version: - - "9" - "10" - "11" include: @@ -36,7 +35,7 @@ jobs: - name: set up python uses: actions/setup-python@v4 with: - python-version: "3.9" + python-version: "3.11" - name: set up rust if: matrix.os != 'ubuntu' diff --git a/pyproject.toml b/pyproject.toml index c6083162..32ea8b14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ authors = [ description = "Tool for modeling and optimization of advanced locomotive powertrains for freight rail decarbonization." readme = "README.md" license = { file = "LICENSE.md" } -requires-python = ">=3.9, <3.12" +requires-python = ">=3.10, <3.12" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: BSD License", @@ -47,6 +47,7 @@ dependencies = [ "pyarrow", "requests", "PyYAML==6.0.2", + "msgpack==1.1.0", ] [project.urls] diff --git a/python/altrios/__init__.py b/python/altrios/__init__.py index 6c8c4fb3..fe855bbd 100644 --- a/python/altrios/__init__.py +++ b/python/altrios/__init__.py @@ -119,27 +119,76 @@ def history_path_list(self, element_as_list:bool=False) -> List[str]: item for item in self.variable_path_list( element_as_list=element_as_list) if "history" in item_str(item) ] - return history_path_list - -def to_pydict(self) -> Dict: + return history_path_list + +# TODO connect to crate features +data_formats = [ + 'yaml', + 'msg_pack', + # 'toml', + 'json', +] + +def to_pydict(self, data_fmt: str = "msg_pack", flatten: bool = False) -> Dict: """ Returns self converted to pure python dictionary with no nested Rust objects + # Arguments + - `flatten`: if True, returns dict without any hierarchy + - `data_fmt`: data format for intermediate conversion step """ - from yaml import load - try: - from yaml import CLoader as Loader - except ImportError: - from yaml import Loader - pydict = load(self.to_yaml(), Loader = Loader) - return pydict + data_fmt = data_fmt.lower() + assert data_fmt in data_formats, f"`data_fmt` must be one of {data_formats}" + match data_fmt: + case "msg_pack": + import msgpack + pydict = msgpack.loads(self.to_msg_pack()) + case "yaml": + from yaml import load + try: + from yaml import CLoader as Loader + except ImportError: + from yaml import Loader + pydict = load(self.to_yaml(), Loader=Loader) + case "json": + from json import loads + pydict = loads(self.to_json()) + + if not flatten: + return pydict + else: + return next(iter(pd.json_normalize(pydict, sep=".").to_dict(orient='records'))) @classmethod -def from_pydict(cls, pydict: Dict) -> Self: +def from_pydict(cls, pydict: Dict, data_fmt: str = "msg_pack", skip_init: bool = True) -> Self: """ Instantiates Self from pure python dictionary + # Arguments + - `pydict`: dictionary to be converted to ALTRIOS object + - `data_fmt`: data format for intermediate conversion step + - `skip_init`: passed to `SerdeAPI` methods to control whether initialization + is skipped """ - import yaml - return cls.from_yaml(yaml.dump(pydict),skip_init=False) + data_fmt = data_fmt.lower() + assert data_fmt in data_formats, f"`data_fmt` must be one of {data_formats}" + match data_fmt.lower(): + case "yaml": + import yaml + obj = cls.from_yaml(yaml.dump(pydict), skip_init=skip_init) + case "msg_pack": + import msgpack + try: + obj = cls.from_msg_pack( + msgpack.packb(pydict), skip_init=skip_init) + except Exception as err: + print( + f"{err}\nFalling back to YAML.") + obj = cls.from_pydict( + pydict, data_fmt="yaml", skip_init=skip_init) + case "json": + from json import dumps + obj = cls.from_json(dumps(pydict), skip_init=skip_init) + + return obj def to_dataframe(self, pandas:bool=False) -> [pd.DataFrame, pl.DataFrame, pl.LazyFrame]: """ diff --git a/python/altrios/altrios_pyo3.pyi b/python/altrios/altrios_pyo3.pyi index d16a7940..f6cbd04d 100644 --- a/python/altrios/altrios_pyo3.pyi +++ b/python/altrios/altrios_pyo3.pyi @@ -16,11 +16,14 @@ class SerdeAPI(object): @classmethod def from_yaml(cls) -> Self: ... @classmethod - def from_file(cls) -> Self: ... + def from_file(cls, skip_init=False) -> Self: ... def to_file(self): ... def to_bincode(self) -> bytes: ... def to_json(self) -> str: ... def to_yaml(self) -> str: ... + def to_pydict(self, data_fmt: str = "msg_pack", flatten: bool = False) -> Dict: ... + @classmethod + def from_pydict(cls, pydict: Dict, data_fmt: str = "msg_pack") -> Self: class Consist(SerdeAPI): diff --git a/python/altrios/demos/sim_manager_demo.py b/python/altrios/demos/sim_manager_demo.py index 65d99b37..f1975738 100644 --- a/python/altrios/demos/sim_manager_demo.py +++ b/python/altrios/demos/sim_manager_demo.py @@ -23,7 +23,7 @@ t0_import = time.perf_counter() t0_total = time.perf_counter() -rail_vehicles=[alt.RailVehicle.from_file(vehicle_file) +rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False) for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')] location_map = alt.import_locations(alt.resources_root() / "networks/default_locations.csv") diff --git a/python/altrios/demos/version_migration_demo.py b/python/altrios/demos/version_migration_demo.py index 8f83e47d..0c059367 100644 --- a/python/altrios/demos/version_migration_demo.py +++ b/python/altrios/demos/version_migration_demo.py @@ -13,8 +13,8 @@ def migrate_network() -> Tuple[alt.Network, alt.Network]: old_network_path = alt.resources_root() / "networks/Taconite_v0.1.6.yaml" new_network_path = alt.resources_root() / "networks/Taconite.yaml" - network_from_old = alt.Network.from_file(old_network_path) - network_from_new = alt.Network.from_file(new_network_path) + network_from_old = alt.Network.from_file(old_network_path, skip_init=False) + network_from_new = alt.Network.from_file(new_network_path, skip_init=False) # `network_from_old` could be used to overwrite the file in the new format with # ``` diff --git a/python/altrios/rollout.py b/python/altrios/rollout.py index 6e8b430b..f841ddeb 100644 --- a/python/altrios/rollout.py +++ b/python/altrios/rollout.py @@ -59,13 +59,13 @@ def simulate_prescribed_rollout( else: demand_paths.append(demand_file) - rail_vehicles=[alt.RailVehicle.from_file(vehicle_file) + rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False) for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')] location_map = alt.import_locations( str(alt.resources_root() / "networks/default_locations.csv") ) - network = alt.Network.from_file(network_filename_path) + network = alt.Network.from_file(network_filename_path, skip_init=False) sim_days = defaults.SIMULATION_DAYS scenarios = [] for idx, scenario_year in enumerate(years): diff --git a/python/altrios/tests/test_serde.py b/python/altrios/tests/test_serde.py new file mode 100644 index 00000000..d9d5c0bb --- /dev/null +++ b/python/altrios/tests/test_serde.py @@ -0,0 +1,125 @@ +import time +import altrios as alt + +SAVE_INTERVAL = 100 +def get_solved_speed_limit_train_sim(): + # Build the train config + rail_vehicle_loaded = alt.RailVehicle.from_file( + alt.resources_root() / "rolling_stock/Manifest_Loaded.yaml") + rail_vehicle_empty = alt.RailVehicle.from_file( + alt.resources_root() / "rolling_stock/Manifest_Empty.yaml") + + # https://docs.rs/altrios-core/latest/altrios_core/train/struct.TrainConfig.html + train_config = alt.TrainConfig( + rail_vehicles=[rail_vehicle_loaded, rail_vehicle_empty], + n_cars_by_type={ + "Manifest_Loaded": 50, + "Manifest_Empty": 50, + }, + train_length_meters=None, + train_mass_kilograms=None, + ) + + # Build the locomotive consist model + # instantiate battery model + # https://docs.rs/altrios-core/latest/altrios_core/consist/locomotive/powertrain/reversible_energy_storage/struct.ReversibleEnergyStorage.html# + res = alt.ReversibleEnergyStorage.from_file( + alt.resources_root() / "powertrains/reversible_energy_storages/Kokam_NMC_75Ah_flx_drive.yaml" + ) + + edrv = alt.ElectricDrivetrain( + pwr_out_frac_interp=[0., 1.], + eta_interp=[0.98, 0.98], + pwr_out_max_watts=5e9, + save_interval=SAVE_INTERVAL, + ) + + bel: alt.Locomotive = alt.Locomotive.build_battery_electric_loco( + reversible_energy_storage=res, + drivetrain=edrv, + loco_params=alt.LocoParams.from_dict(dict( + pwr_aux_offset_watts=8.55e3, + pwr_aux_traction_coeff_ratio=540.e-6, + force_max_newtons=667.2e3, + ))) + + # construct a vector of one BEL and several conventional locomotives + loco_vec = [bel.clone()] + [alt.Locomotive.default()] * 7 + # instantiate consist + loco_con = alt.Consist( + loco_vec + ) + + # Instantiate the intermediate `TrainSimBuilder` + tsb = alt.TrainSimBuilder( + train_id="0", + origin_id="A", + destination_id="B", + train_config=train_config, + loco_con=loco_con, + ) + + # Load the network and construct the timed link path through the network. + network = alt.Network.from_file( + alt.resources_root() / 'networks/simple_corridor_network.yaml') + + location_map = alt.import_locations( + alt.resources_root() / "networks/simple_corridor_locations.csv") + train_sim: alt.SetSpeedTrainSim = tsb.make_speed_limit_train_sim( + location_map=location_map, + save_interval=1, + ) + train_sim.set_save_interval(SAVE_INTERVAL) + est_time_net, _consist = alt.make_est_times(train_sim, network) + + timed_link_path = alt.run_dispatch( + network, + alt.SpeedLimitTrainSimVec([train_sim]), + [est_time_net], + False, + False, + )[0] + + train_sim.walk_timed_path( + network=network, + timed_path=timed_link_path, + ) + assert len(train_sim.history) > 1 + + return train_sim + + +def test_pydict(): + ts = get_solved_speed_limit_train_sim() + + t0 = time.perf_counter_ns() + ts_dict_msg = ts.to_pydict(flatten=False, data_fmt="msg_pack") + ts_msg = alt.SpeedLimitTrainSim.from_pydict( + ts_dict_msg, data_fmt="msg_pack") + t1 = time.perf_counter_ns() + t_msg = t1 - t0 + print(f"\nElapsed time for MessagePack: {t_msg:.3e} ns ") + + t0 = time.perf_counter_ns() + ts_dict_yaml = ts.to_pydict(flatten=False, data_fmt="yaml") + ts_yaml = alt.SpeedLimitTrainSim.from_pydict(ts_dict_yaml, data_fmt="yaml") + t1 = time.perf_counter_ns() + t_yaml = t1 - t0 + print(f"Elapsed time for YAML: {t_yaml:.3e} ns ") + print(f"YAML time per MessagePack time: {(t_yaml / t_msg):.3e} ") + + t0 = time.perf_counter_ns() + ts_dict_json = ts.to_pydict(flatten=False, data_fmt="json") + _ts_json = alt.SpeedLimitTrainSim.from_pydict( + ts_dict_json, data_fmt="json") + t1 = time.perf_counter_ns() + t_json = t1 - t0 + print(f"Elapsed time for json: {t_json:.3e} ns ") + print(f"JSON time per MessagePack time: {(t_json / t_msg):.3e} ") + + # `to_pydict` is necessary because of some funkiness with direct equality comparison + assert ts_msg.to_pydict() == ts.to_pydict() + assert ts_yaml.to_pydict() == ts.to_pydict() + +if __name__ == "__main__": + test_pydict() diff --git a/python/altrios/train_planner.py b/python/altrios/train_planner.py index e8165e6f..e906c212 100644 --- a/python/altrios/train_planner.py +++ b/python/altrios/train_planner.py @@ -1129,7 +1129,7 @@ def run_train_planner( if __name__ == "__main__": - rail_vehicles=[alt.RailVehicle.from_file(vehicle_file) + rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False) for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')] location_map = alt.import_locations( diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 93cf303b..d39ac947 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -70,6 +70,7 @@ dependencies = [ "pyo3-polars", "rayon", "readonly", + "rmp-serde", "serde", "serde-this-or-that", "serde_json", @@ -2173,6 +2174,28 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rustc-demangle" version = "0.1.24" diff --git a/rust/altrios-core/Cargo.toml b/rust/altrios-core/Cargo.toml index 7d46e129..7dd3f70b 100644 --- a/rust/altrios-core/Cargo.toml +++ b/rust/altrios-core/Cargo.toml @@ -13,6 +13,7 @@ rust-version = { workspace = true } [dependencies] csv = "1.1.6" serde = { version = "1.0.136", features = ["derive"] } +rmp-serde = { version = "1.3.0", optional = true } serde_yaml = "0.8.23" serde_json = "1.0" uom = { workspace = true, features = ["use_serde"] } @@ -56,9 +57,13 @@ tempfile = "3.10.1" derive_more = { version = "1.0.0", features = ["from_str", "from", "is_variant", "try_into"] } [features] -default = [] +default = ["serde-default"] +## Enables several text file formats for serialization and deserialization +serde-default = ["msgpack"] ## Exposes ALTRIOS structs, methods, and functions to Python. pyo3 = ["dep:pyo3"] +## Enables message pack serialization and deserialization via `rmp-serde` +msgpack = ["dep:rmp-serde"] [lints.rust] # `'cfg(debug_advance_rewind)'` is expected for debugging in `advance_rewind.rs` diff --git a/rust/altrios-core/altrios-proc-macros/src/altrios_api/mod.rs b/rust/altrios-core/altrios-proc-macros/src/altrios_api/mod.rs index 90387240..521d6bcb 100644 --- a/rust/altrios-core/altrios-proc-macros/src/altrios_api/mod.rs +++ b/rust/altrios-core/altrios-proc-macros/src/altrios_api/mod.rs @@ -38,11 +38,18 @@ pub(crate) fn altrios_api(attr: TokenStream, item: TokenStream) -> TokenStream { self.to_str(format) } - /// See [SerdeAPI::from_str] + /// Read (deserialize) an object from a string + /// + /// # Arguments: + /// + /// * `contents`: `str` - The string containing the object data + /// * `format`: `str` - The source format, any of those listed in [`ACCEPTED_STR_FORMATS`](`SerdeAPI::ACCEPTED_STR_FORMATS`) + /// #[staticmethod] #[pyo3(name = "from_str")] - pub fn from_str_py(contents: &str, format: &str) -> anyhow::Result { - Self::from_str(contents, format) + #[pyo3(signature = (contents, format, skip_init=None))] + pub fn from_str_py(contents: &str, format: &str, skip_init: Option) -> PyResult { + Ok(SerdeAPI::from_str(contents, format, skip_init.unwrap_or_default())?) } /// See [SerdeAPI::to_json] @@ -51,11 +58,17 @@ pub(crate) fn altrios_api(attr: TokenStream, item: TokenStream) -> TokenStream { self.to_json() } - /// See [SerdeAPI::from_json] + /// Read (deserialize) an object from a JSON string + /// + /// # Arguments + /// + /// * `json_str`: `str` - JSON-formatted string to deserialize from + /// #[staticmethod] #[pyo3(name = "from_json")] - fn from_json_py(json_str: &str) -> anyhow::Result { - Self::from_json(json_str) + #[pyo3(signature = (json_str, skip_init=None))] + pub fn from_json_py(json_str: &str, skip_init: Option) -> PyResult { + Ok(Self::from_json(json_str, skip_init.unwrap_or_default())?) } /// See [SerdeAPI::to_yaml] @@ -64,11 +77,41 @@ pub(crate) fn altrios_api(attr: TokenStream, item: TokenStream) -> TokenStream { self.to_yaml() } - /// See [SerdeAPI::from_yaml] + /// Read (deserialize) an object from a YAML string + /// + /// # Arguments + /// + /// * `yaml_str`: `str` - YAML-formatted string to deserialize from + /// #[staticmethod] #[pyo3(name = "from_yaml")] - fn from_yaml_py(yaml_str: &str) -> anyhow::Result { - Self::from_yaml(yaml_str) + #[pyo3(signature = (yaml_str, skip_init=None))] + pub fn from_yaml_py(yaml_str: &str, skip_init: Option) -> PyResult { + Ok(Self::from_yaml(yaml_str, skip_init.unwrap_or_default())?) + } + + /// Write (serialize) an object to a message pack + #[cfg(feature = "msgpack")] + #[pyo3(name = "to_msg_pack")] + // TODO: figure from Kyle out how to use `PyIOError` + pub fn to_msg_pack_py<'py>(&self, py: Python<'py>) -> anyhow::Result> { + Ok(PyBytes::new_bound(py, &self.to_msg_pack()?)) + } + + /// Read (deserialize) an object from a message pack + /// + /// # Arguments + /// * `msg_pack`: message pack + #[cfg(feature = "msgpack")] + #[staticmethod] + #[pyo3(name = "from_msg_pack")] + #[pyo3(signature = (msg_pack, skip_init=None))] + // TODO: figure from Kyle out how to use `PyIOError` + pub fn from_msg_pack_py(msg_pack: &Bound, skip_init: Option) -> anyhow::Result { + Self::from_msg_pack( + msg_pack.as_bytes(), + skip_init.unwrap_or_default() + ) } /// See [SerdeAPI::to_bincode] @@ -136,14 +179,15 @@ pub(crate) fn altrios_api(attr: TokenStream, item: TokenStream) -> TokenStream { /// #[staticmethod] #[pyo3(name = "from_file")] - pub fn from_file_py(filepath: &Bound) -> anyhow::Result { - Self::from_file(PathBuf::extract_bound(filepath)?) + #[pyo3(signature = (filepath, skip_init=None))] + pub fn from_file_py(filepath: &Bound, skip_init: Option) -> PyResult { + Ok(Self::from_file(PathBuf::extract_bound(filepath)?, skip_init.unwrap_or_default())?) } } }; let mut final_output = TokenStream2::default(); final_output.extend::(quote! { - #[cfg_attr(feature="pyo3", pyclass(module="altrios_pyo3", subclass))] + #[cfg_attr(feature="pyo3", pyclass(module="altrios_pyo3", subclass, eq))] }); let mut output: TokenStream2 = ast.to_token_stream(); output.extend(impl_block); diff --git a/rust/altrios-core/src/consist/consist_model.rs b/rust/altrios-core/src/consist/consist_model.rs index 2a217fcf..e175f808 100644 --- a/rust/altrios-core/src/consist/consist_model.rs +++ b/rust/altrios-core/src/consist/consist_model.rs @@ -58,7 +58,7 @@ use super::*; ); } - fn get_hct(&self) -> String { + fn get_pdct(&self) -> String { // make a `describe` function match &self.pdct { PowerDistributionControlType::RESGreedy(val) => format!("{val:?}"), @@ -115,7 +115,7 @@ pub struct Consist { #[serde(default)] #[serde(skip_serializing_if = "EqDefault::eq_default")] pub state: ConsistState, - #[serde(skip_serializing_if = "ConsistStateHistoryVec::is_empty", default)] + #[serde(default, skip_serializing_if = "ConsistStateHistoryVec::is_empty")] /// Custom vector of [Self::state] pub history: ConsistStateHistoryVec, #[api(skip_set, skip_get)] // custom needed for this diff --git a/rust/altrios-core/src/consist/consist_utils.rs b/rust/altrios-core/src/consist/consist_utils.rs index 13f669b0..3a6055da 100644 --- a/rust/altrios-core/src/consist/consist_utils.rs +++ b/rust/altrios-core/src/consist/consist_utils.rs @@ -267,6 +267,7 @@ impl SolvePower for FrontAndBack { todo!() // not needed urgently } } + /// Variants of this enum are used to determine what control strategy gets used for distributing /// power required from or delivered to during negative tractive power each locomotive. #[derive(PartialEq, Clone, Deserialize, Serialize, Debug, SerdeAPI)] diff --git a/rust/altrios-core/src/consist/locomotive/locomotive_model.rs b/rust/altrios-core/src/consist/locomotive/locomotive_model.rs index 13751580..453996c4 100644 --- a/rust/altrios-core/src/consist/locomotive/locomotive_model.rs +++ b/rust/altrios-core/src/consist/locomotive/locomotive_model.rs @@ -617,7 +617,7 @@ pub struct Locomotive { #[api(skip_set, skip_get)] save_interval: Option, /// Custom vector of [Self::state] - #[serde(skip_serializing_if = "LocomotiveStateHistoryVec::is_empty", default)] + #[serde(default, skip_serializing_if = "LocomotiveStateHistoryVec::is_empty")] pub history: LocomotiveStateHistoryVec, #[serde(default = "utils::return_true")] /// If true, requires power demand to not exceed consist diff --git a/rust/altrios-core/src/consist/locomotive/powertrain/electric_drivetrain.rs b/rust/altrios-core/src/consist/locomotive/powertrain/electric_drivetrain.rs index 244af0d2..a38dddb5 100644 --- a/rust/altrios-core/src/consist/locomotive/powertrain/electric_drivetrain.rs +++ b/rust/altrios-core/src/consist/locomotive/powertrain/electric_drivetrain.rs @@ -77,8 +77,8 @@ pub struct ElectricDrivetrain { pub save_interval: Option, /// Custom vector of [Self::state] #[serde( - skip_serializing_if = "ElectricDrivetrainStateHistoryVec::is_empty", - default + default, + skip_serializing_if = "ElectricDrivetrainStateHistoryVec::is_empty" )] pub history: ElectricDrivetrainStateHistoryVec, } @@ -250,7 +250,7 @@ impl Default for ElectricDrivetrain { fn default() -> Self { // let file_contents = include_str!(EDRV_DEFAULT_PATH_STR); let file_contents = include_str!("electric_drivetrain.default.yaml"); - Self::from_yaml(file_contents).unwrap() + Self::from_yaml(file_contents, false).unwrap() } } diff --git a/rust/altrios-core/src/consist/locomotive/powertrain/fuel_converter.rs b/rust/altrios-core/src/consist/locomotive/powertrain/fuel_converter.rs index 11d06b27..c611db68 100644 --- a/rust/altrios-core/src/consist/locomotive/powertrain/fuel_converter.rs +++ b/rust/altrios-core/src/consist/locomotive/powertrain/fuel_converter.rs @@ -75,8 +75,8 @@ pub struct FuelConverter { pub save_interval: Option, /// Custom vector of [Self::state] #[serde( - skip_serializing_if = "FuelConverterStateHistoryVec::is_empty", - default + default, + skip_serializing_if = "FuelConverterStateHistoryVec::is_empty" )] pub history: FuelConverterStateHistoryVec, // TODO: spec out fuel tank size and track kg of fuel } @@ -84,7 +84,7 @@ pub struct FuelConverter { impl Default for FuelConverter { fn default() -> Self { let file_contents = include_str!("fuel_converter.default.yaml"); - Self::from_yaml(file_contents).unwrap() + Self::from_yaml(file_contents, false).unwrap() } } diff --git a/rust/altrios-core/src/consist/locomotive/powertrain/generator.rs b/rust/altrios-core/src/consist/locomotive/powertrain/generator.rs index 10fe88bd..356f8250 100644 --- a/rust/altrios-core/src/consist/locomotive/powertrain/generator.rs +++ b/rust/altrios-core/src/consist/locomotive/powertrain/generator.rs @@ -102,7 +102,7 @@ pub struct Generator { /// Time step interval between saves. 1 is a good option. If None, no saving occurs. pub save_interval: Option, /// Custom vector of [Self::state] - #[serde(skip_serializing_if = "GeneratorStateHistoryVec::is_empty", default)] + #[serde(default, skip_serializing_if = "GeneratorStateHistoryVec::is_empty")] pub history: GeneratorStateHistoryVec, } @@ -306,7 +306,7 @@ impl Generator { impl Default for Generator { fn default() -> Self { let file_contents = include_str!("generator.default.yaml"); - Self::from_yaml(file_contents).unwrap() + Self::from_yaml(file_contents, false).unwrap() } } diff --git a/rust/altrios-core/src/consist/locomotive/powertrain/reversible_energy_storage.rs b/rust/altrios-core/src/consist/locomotive/powertrain/reversible_energy_storage.rs index 119464f8..0e1c8438 100644 --- a/rust/altrios-core/src/consist/locomotive/powertrain/reversible_energy_storage.rs +++ b/rust/altrios-core/src/consist/locomotive/powertrain/reversible_energy_storage.rs @@ -187,8 +187,8 @@ pub struct ReversibleEnergyStorage { /// Time step interval at which history is saved pub save_interval: Option, #[serde( - skip_serializing_if = "ReversibleEnergyStorageStateHistoryVec::is_empty", - default + default, + skip_serializing_if = "ReversibleEnergyStorageStateHistoryVec::is_empty" )] /// Custom vector of [Self::state] pub history: ReversibleEnergyStorageStateHistoryVec, @@ -197,7 +197,7 @@ pub struct ReversibleEnergyStorage { impl Default for ReversibleEnergyStorage { fn default() -> Self { let file_contents = include_str!("reversible_energy_storage.default.yaml"); - let mut res = Self::from_yaml(file_contents).unwrap(); + let mut res = Self::from_yaml(file_contents, false).unwrap(); res.state.soc = res.max_soc; res } diff --git a/rust/altrios-core/src/meet_pass/dispatch.rs b/rust/altrios-core/src/meet_pass/dispatch.rs index e7a9f5b0..7b3573d3 100644 --- a/rust/altrios-core/src/meet_pass/dispatch.rs +++ b/rust/altrios-core/src/meet_pass/dispatch.rs @@ -304,7 +304,7 @@ mod test_dispatch { let network_file_path = project_root::get_project_root() .unwrap() .join("../python/altrios/resources/networks/Taconite.yaml"); - let network = Network::from_file(network_file_path).unwrap(); + let network = Network::from_file(network_file_path, false).unwrap(); let train_sims = vec![ crate::train::speed_limit_train_sim_fwd(), diff --git a/rust/altrios-core/src/meet_pass/train_disp/mod.rs b/rust/altrios-core/src/meet_pass/train_disp/mod.rs index 0263992c..4ebe48d4 100644 --- a/rust/altrios-core/src/meet_pass/train_disp/mod.rs +++ b/rust/altrios-core/src/meet_pass/train_disp/mod.rs @@ -223,7 +223,7 @@ mod test_train_disp { let network_file_path = project_root::get_project_root() .unwrap() .join("../python/altrios/resources/networks/Taconite.yaml"); - let network = Network::from_file(network_file_path).unwrap(); + let network = Network::from_file(network_file_path, false).unwrap(); let speed_limit_train_sim = crate::train::speed_limit_train_sim_fwd(); let est_times = make_est_times(speed_limit_train_sim.clone(), network) @@ -248,7 +248,7 @@ mod test_train_disp { let network_file_path = project_root::get_project_root() .unwrap() .join("../python/altrios/resources/networks/Taconite.yaml"); - let network = Network::from_file(network_file_path).unwrap(); + let network = Network::from_file(network_file_path, false).unwrap(); let speed_limit_train_sim = crate::train::speed_limit_train_sim_rev(); let est_times = make_est_times(speed_limit_train_sim.clone(), network) diff --git a/rust/altrios-core/src/track/link/link_impl.rs b/rust/altrios-core/src/track/link/link_impl.rs index 11a2f75f..d8225615 100644 --- a/rust/altrios-core/src/track/link/link_impl.rs +++ b/rust/altrios-core/src/track/link/link_impl.rs @@ -310,22 +310,22 @@ impl ObjState for Network { } impl SerdeAPI for Network { - fn from_file>(filepath: P) -> anyhow::Result { + fn from_file>(filepath: P, skip_init: bool) -> anyhow::Result { let filepath = filepath.as_ref(); let extension = filepath .extension() .and_then(OsStr::to_str) .with_context(|| format!("File extension could not be parsed: {filepath:?}"))?; - let file = File::open(filepath).with_context(|| { + let mut file = File::open(filepath).with_context(|| { if !filepath.exists() { format!("File not found: {filepath:?}") } else { format!("Could not open file: {filepath:?}") } })?; - let mut network = match Self::from_reader(file, extension) { + let mut network = match Self::from_reader(&mut file, extension, skip_init) { Ok(network) => network, - Err(err) => NetworkOld::from_file(filepath) + Err(err) => NetworkOld::from_file(filepath, false) .map_err(|old_err| { anyhow!("\nattempting to load as `Network`:\n{}\nattempting to load as `NetworkOld`:\n{}", err, old_err) })? @@ -617,7 +617,10 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); let temp_file_path = tempdir.path().join("links_test2.yaml"); links.to_file(temp_file_path.clone()).unwrap(); - assert_eq!(Vec::::from_file(temp_file_path).unwrap(), links); + assert_eq!( + Vec::::from_file(temp_file_path, false).unwrap(), + links + ); tempdir.close().unwrap(); } @@ -626,7 +629,7 @@ mod tests { let network_file_path = project_root::get_project_root() .unwrap() .join("../python/altrios/resources/networks/Taconite.yaml"); - let network_speed_sets = Network::from_file(network_file_path).unwrap(); + let network_speed_sets = Network::from_file(network_file_path, false).unwrap(); let mut network_speed_set = network_speed_sets.clone(); network_speed_set .set_speed_set_for_train_type(TrainType::Freight) diff --git a/rust/altrios-core/src/track/link/speed/speed_set.rs b/rust/altrios-core/src/track/link/speed/speed_set.rs index 4498aff7..ca50786b 100644 --- a/rust/altrios-core/src/track/link/speed/speed_set.rs +++ b/rust/altrios-core/src/track/link/speed/speed_set.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; #[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, SerdeAPI, Hash)] #[repr(u8)] -#[cfg_attr(feature = "pyo3", pyclass(eq, eq_int))] +#[cfg_attr(feature = "pyo3", pyclass(eq))] /// Enum with variants representing train types pub enum TrainType { #[default] diff --git a/rust/altrios-core/src/train/friction_brakes.rs b/rust/altrios-core/src/train/friction_brakes.rs index c982238f..b86aa6f7 100644 --- a/rust/altrios-core/src/train/friction_brakes.rs +++ b/rust/altrios-core/src/train/friction_brakes.rs @@ -43,7 +43,7 @@ pub struct FricBrake { #[serde(default)] #[serde(skip_serializing_if = "EqDefault::eq_default")] pub state: FricBrakeState, - #[serde(skip_serializing_if = "FricBrakeStateHistoryVec::is_empty", default)] + #[serde(default, skip_serializing_if = "FricBrakeStateHistoryVec::is_empty")] /// Custom vector of [Self::state] pub history: FricBrakeStateHistoryVec, pub save_interval: Option, diff --git a/rust/altrios-core/src/train/set_speed_train_sim.rs b/rust/altrios-core/src/train/set_speed_train_sim.rs index 75ec44ee..1b15d746 100644 --- a/rust/altrios-core/src/train/set_speed_train_sim.rs +++ b/rust/altrios-core/src/train/set_speed_train_sim.rs @@ -205,11 +205,11 @@ pub struct SpeedTraceElement { save_interval: Option, ) -> Self { let path_tpc = match path_tpc_file { - Some(file) => PathTpc::from_file(file).unwrap(), + Some(file) => PathTpc::from_file(file, false).unwrap(), None => PathTpc::valid() }; let train_res = match train_res_file { - Some(file) => TrainRes::from_file(file).unwrap(), + Some(file) => TrainRes::from_file(file, false).unwrap(), None => TrainRes::valid() }; @@ -290,7 +290,7 @@ pub struct SetSpeedTrainSim { #[api(skip_set)] path_tpc: PathTpc, /// Custom vector of [Self::state] - #[serde(skip_serializing_if = "TrainStateHistoryVec::is_empty", default)] + #[serde(default, skip_serializing_if = "TrainStateHistoryVec::is_empty")] pub history: TrainStateHistoryVec, #[api(skip_set, skip_get)] save_interval: Option, diff --git a/rust/altrios-core/src/train/speed_limit_train_sim.rs b/rust/altrios-core/src/train/speed_limit_train_sim.rs index 0b81c7fa..e9cfe43d 100644 --- a/rust/altrios-core/src/train/speed_limit_train_sim.rs +++ b/rust/altrios-core/src/train/speed_limit_train_sim.rs @@ -101,7 +101,7 @@ impl From<&Vec> for TimedLinkPath { #[pyo3(name = "extend_path")] pub fn extend_path_py(&mut self, network_file_path: String, link_path: Vec) -> anyhow::Result<()> { - let network = Vec::::from_file(network_file_path).unwrap(); + let network = Vec::::from_file(network_file_path, false).unwrap(); self.extend_path(&network, &link_path)?; Ok(()) @@ -153,7 +153,7 @@ pub struct SpeedLimitTrainSim { pub braking_points: BrakingPoints, pub fric_brake: FricBrake, /// Custom vector of [Self::state] - #[serde(skip_serializing_if = "TrainStateHistoryVec::is_empty", default)] + #[serde(default, skip_serializing_if = "TrainStateHistoryVec::is_empty")] pub history: TrainStateHistoryVec, #[api(skip_set, skip_get)] save_interval: Option, diff --git a/rust/altrios-core/src/traits.rs b/rust/altrios-core/src/traits.rs index 51225a77..81afe4fc 100644 --- a/rust/altrios-core/src/traits.rs +++ b/rust/altrios-core/src/traits.rs @@ -150,20 +150,20 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> { /// /// * `filepath`: The filepath from which to read the object /// - fn from_file>(filepath: P) -> anyhow::Result { + fn from_file>(filepath: P, skip_init: bool) -> anyhow::Result { let filepath = filepath.as_ref(); let extension = filepath .extension() .and_then(OsStr::to_str) .with_context(|| format!("File extension could not be parsed: {filepath:?}"))?; - let file = File::open(filepath).with_context(|| { + let mut file = File::open(filepath).with_context(|| { if !filepath.exists() { format!("File not found: {filepath:?}") } else { format!("Could not open file: {filepath:?}") } })?; - Self::from_reader(file, extension) + Self::from_reader(&mut file, extension, skip_init) } /// Write (serialize) an object into a string @@ -190,11 +190,11 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> { /// * `contents` - The string containing the object data /// * `format` - The source format, any of those listed in [`ACCEPTED_STR_FORMATS`](`SerdeAPI::ACCEPTED_STR_FORMATS`) /// - fn from_str>(contents: S, format: &str) -> anyhow::Result { + fn from_str>(contents: S, format: &str, skip_init: bool) -> anyhow::Result { Ok( match format.trim_start_matches('.').to_lowercase().as_str() { - "yaml" | "yml" => Self::from_yaml(contents)?, - "json" => Self::from_json(contents)?, + "yaml" | "yml" => Self::from_yaml(contents, skip_init)?, + "json" => Self::from_json(contents, skip_init)?, _ => bail!( "Unsupported format {format:?}, must be one of {:?}", Self::ACCEPTED_STR_FORMATS @@ -210,20 +210,24 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> { /// * `rdr` - The reader from which to read object data /// * `format` - The source format, any of those listed in [`ACCEPTED_BYTE_FORMATS`](`SerdeAPI::ACCEPTED_BYTE_FORMATS`) /// - fn from_reader(rdr: R, format: &str) -> anyhow::Result - where - R: std::io::Read, - { + fn from_reader( + rdr: &mut R, + format: &str, + skip_init: bool, + ) -> anyhow::Result { let mut deserialized: Self = match format.trim_start_matches('.').to_lowercase().as_str() { "yaml" | "yml" => serde_yaml::from_reader(rdr)?, "json" => serde_json::from_reader(rdr)?, - "bin" => bincode::deserialize_from(rdr)?, + #[cfg(feature = "msgpack")] + "msgpack" => rmp_serde::decode::from_read(rdr)?, _ => bail!( "Unsupported format {format:?}, must be one of {:?}", Self::ACCEPTED_BYTE_FORMATS ), }; - deserialized.init()?; + if !skip_init { + deserialized.init()?; + } Ok(deserialized) } @@ -232,15 +236,17 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> { Ok(serde_json::to_string(&self)?) } - /// Read (deserialize) an object to a JSON string + /// Read (deserialize) an object from a JSON string /// /// # Arguments /// /// * `json_str` - JSON-formatted string to deserialize from /// - fn from_json>(json_str: S) -> anyhow::Result { + fn from_json>(json_str: S, skip_init: bool) -> anyhow::Result { let mut json_de: Self = serde_json::from_str(json_str.as_ref())?; - json_de.init()?; + if !skip_init { + json_de.init()?; + } Ok(json_de) } @@ -249,15 +255,38 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> { Ok(serde_yaml::to_string(&self)?) } + /// Write (serialize) an object to a message pack + #[cfg(feature = "msgpack")] + fn to_msg_pack(&self) -> anyhow::Result> { + Ok(rmp_serde::encode::to_vec_named(&self)?) + } + + /// Read (deserialize) an object from a message pack + /// + /// # Arguments + /// + /// * `msg_pack` - message pack object + /// + #[cfg(feature = "msgpack")] + fn from_msg_pack(msg_pack: &[u8], skip_init: bool) -> anyhow::Result { + let mut msg_pack_de: Self = rmp_serde::decode::from_slice(msg_pack)?; + if !skip_init { + msg_pack_de.init()?; + } + Ok(msg_pack_de) + } + /// Read (deserialize) an object from a YAML string /// /// # Arguments /// /// * `yaml_str` - YAML-formatted string to deserialize from /// - fn from_yaml>(yaml_str: S) -> anyhow::Result { + fn from_yaml>(yaml_str: S, skip_init: bool) -> anyhow::Result { let mut yaml_de: Self = serde_yaml::from_str(yaml_str.as_ref())?; - yaml_de.init()?; + if !skip_init { + yaml_de.init()?; + } Ok(yaml_de) }