diff --git a/nnef/src/ast/dump_doc.rs b/nnef/src/ast/dump_doc.rs index 7483e4b33c..c103c65251 100644 --- a/nnef/src/ast/dump_doc.rs +++ b/nnef/src/ast/dump_doc.rs @@ -37,7 +37,7 @@ impl<'a> DocDumper<'a> { Dumper::new(&Nnef::default(), self.w).fragment_decl(&fragment_decl)?; } // Generate and write Primitive declarations. - for primitive in registry.primitives.values().flatten().sorted_by_key(|v| &v.decl.id) { + for primitive in registry.primitives.values().sorted_by_key(|v| &v.decl.id) { primitive.docstrings.iter().flatten() .try_for_each(|d| writeln!(self.w, "# {d}"))?; diff --git a/nnef/src/deser.rs b/nnef/src/deser.rs index 0e32d66b03..03c4cc4268 100644 --- a/nnef/src/deser.rs +++ b/nnef/src/deser.rs @@ -224,7 +224,9 @@ impl<'mb> ModelBuilder<'mb> { ); } } - for registry in &self.framework.registries { + + // We start with the registry that has been added last + for registry in self.framework.registries.iter().rev() { if self.registries.contains(®istry.id) { if let Some(outputs) = registry .deserialize(self, invocation, dt) diff --git a/nnef/src/framework.rs b/nnef/src/framework.rs index 4ee5edd766..582374c58d 100644 --- a/nnef/src/framework.rs +++ b/nnef/src/framework.rs @@ -57,20 +57,6 @@ impl Nnef { self } - pub fn with_primitive_alternative( - mut self, - registry_id: &str, - op_id: &str, - func: ToTract, - ) -> TractResult { - if let Some(reg) = self.registries.iter_mut().find(|it| it.id == registry_id.into()) { - reg.register_primitive_alternative(op_id, func).with_context(|| { - anyhow!("Impossible to add new primitive alternative for op {}", op_id) - })?; - } - Ok(self) - } - pub fn enable_tract_resource(&mut self) { self.registries.push(crate::ops::tract_core()); } diff --git a/nnef/src/registry.rs b/nnef/src/registry.rs index 654d365883..a6162684b1 100644 --- a/nnef/src/registry.rs +++ b/nnef/src/registry.rs @@ -41,7 +41,7 @@ pub struct Registry { pub docstrings: Option>, pub aliases: Vec, pub fragments: HashMap, - pub primitives: HashMap>, + pub primitives: HashMap, pub unit_element_wise_ops: Vec<(Identifier, Box)>, pub element_wise_ops: Vec<(Identifier, TypeId, FromTract, Vec, ToTract)>, pub binary_ops: Vec, @@ -89,36 +89,12 @@ impl Registry { results: results.iter().cloned().map(|it| it.into()).collect(), }; let primitive_decl = PrimitiveDecl { decl, docstrings: None, to_tract: func }; - self.primitives.insert(id.clone(), vec![primitive_decl]); + self.primitives.insert(id.clone(), primitive_decl); self.primitives .get_mut(&id) - .and_then(|it| it.last_mut()) .expect("Unexpected empty entry in primitives hashmap") } - pub fn register_primitive_alternative( - &mut self, - id: impl AsRef, - func: ToTract, - ) -> TractResult<&mut PrimitiveDecl> { - let id: Identifier = id.as_ref().into(); - self.primitives.get_mut(&id).map_or_else( - || bail!("No primitive with name '{}' in registry: {}", id.as_ref(), self.id.as_ref()), - |it| -> TractResult<()> { - let last = it.last().unwrap_or_else(|| panic!("Unexpected empty primitive declaration for '{}'", id.as_ref())); - let mut new = last.clone(); - new.to_tract = func; - it.insert(0, new); - Ok(()) - }, - )?; - Ok(self - .primitives - .get_mut(&id) - .and_then(|it| it.last_mut()) - .expect("Unexpected empty entry in primitives hashmap")) - } - pub fn register_fragment(&mut self, def: FragmentDef) { self.fragments.insert(def.decl.id.clone(), def); } @@ -189,28 +165,14 @@ impl Registry { invocation: &ast::Invocation, dt: &[Option], ) -> TractResult> { - if let Some(p) = self.primitives.get(&invocation.id) { - let out_value = p - .iter() - .enumerate() - .find_map(|(idx, op)| { - let resolved = ResolvedInvocation { - invocation, - default_params: &op.decl.parameters, - dt_from_quant_file: dt, - }; - (op.to_tract)(builder, &resolved) - .map_err(|err| { - log::debug!( - "Failed to load {:?} with deserializer {}: {:?}", - &invocation.id, - idx, - &err - ); - }) - .ok() - }) - .ok_or(anyhow!("No valid deserializer found for {:?}", &invocation.id))?; + if let Some(op) = self.primitives.get(&invocation.id) { + let resolved = ResolvedInvocation { + invocation, + default_params: &op.decl.parameters, + dt_from_quant_file: dt, + }; + let out_value = (op.to_tract)(builder, &resolved) + .with_context(|| format!("Deserializing op `{}'", invocation.id.0))?; return Ok(Some(out_value)); } if let Some(ew) = self.unit_element_wise_ops.iter().find(|ew| ew.0 == invocation.id) {