diff --git a/ffi/py/tests/mobilenet_onnx_test.py b/ffi/py/tests/mobilenet_onnx_test.py index 167a5a00b8..6b0265fc13 100644 --- a/ffi/py/tests/mobilenet_onnx_test.py +++ b/ffi/py/tests/mobilenet_onnx_test.py @@ -123,6 +123,7 @@ def test_typed_model_to_nnef_and_back(): assert str(reloaded.output_fact(0)) == "B,1000,F32" path = tmpdirname / "nnef.tar.gz" + nnef = nnef.with_extended_identifier_syntax() nnef.write_model_to_tar_gz(typed, path) reloaded = nnef.model_for_path(path) assert str(reloaded.input_fact(0)) == "B,3,224,224,F32" diff --git a/ffi/py/tract/nnef.py b/ffi/py/tract/nnef.py index 98676352f8..030b05345b 100644 --- a/ffi/py/tract/nnef.py +++ b/ffi/py/tract/nnef.py @@ -72,6 +72,14 @@ def with_pulse(self) -> "Nnef": check(lib.tract_nnef_enable_pulse(self.ptr)) return self + def with_extended_identifier_syntax(self) -> "Nnef": + """ + Enable tract-opl extensions to NNEF for extended identifiers (will support PyTorch 2 path-like ids) + """ + self._valid() + check(lib.tract_nnef_allow_extended_identifier_syntax(self.ptr, True)) + return self + def write_model_to_dir(self, model: Model, path: Union[str, Path]) -> None: """ Save `model` as a NNEF directory model in `path`. diff --git a/ffi/src/lib.rs b/ffi/src/lib.rs index a1b5a7ef9c..5d00ad7b9d 100644 --- a/ffi/src/lib.rs +++ b/ffi/src/lib.rs @@ -232,6 +232,15 @@ pub unsafe extern "C" fn tract_nnef_enable_pulse(nnef: *mut TractNnef) -> TRACT_ }) } +#[no_mangle] +pub unsafe extern "C" fn tract_nnef_allow_extended_identifier_syntax(nnef: *mut TractNnef, enable: bool) -> TRACT_RESULT { + wrap(|| unsafe { + check_not_null!(nnef); + (*nnef).0.allow_extended_identifier_syntax(enable); + Ok(()) + }) +} + /// Destroy the NNEF parser. It is safe to detroy the NNEF parser once the model had been loaded. #[no_mangle] pub unsafe extern "C" fn tract_nnef_destroy(nnef: *mut *mut TractNnef) -> TRACT_RESULT {