Skip to content

Commit

Permalink
Remove primitive alternatives -> can be done through registry
Browse files Browse the repository at this point in the history
  • Loading branch information
emricksinisonos authored and kali committed Apr 19, 2023
1 parent 1fbf2f9 commit db4f3a9
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 64 deletions.
2 changes: 1 addition & 1 deletion nnef/src/ast/dump_doc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl<'a> DocDumper<'a> {
Dumper::new(&Nnef::default(), self.w).fragment_decl(&fragment_decl)?;
}
// Generate and write Primitive declarations.
for primitive in registry.primitives.values().flatten().sorted_by_key(|v| &v.decl.id) {
for primitive in registry.primitives.values().sorted_by_key(|v| &v.decl.id) {
primitive.docstrings.iter().flatten()
.try_for_each(|d| writeln!(self.w, "# {d}"))?;

Expand Down
4 changes: 3 additions & 1 deletion nnef/src/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ impl<'mb> ModelBuilder<'mb> {
);
}
}
for registry in &self.framework.registries {

// We start with the registry that has been added last
for registry in self.framework.registries.iter().rev() {
if self.registries.contains(&registry.id) {
if let Some(outputs) = registry
.deserialize(self, invocation, dt)
Expand Down
14 changes: 0 additions & 14 deletions nnef/src/framework.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,6 @@ impl Nnef {
self
}

pub fn with_primitive_alternative(
mut self,
registry_id: &str,
op_id: &str,
func: ToTract,
) -> TractResult<Self> {
if let Some(reg) = self.registries.iter_mut().find(|it| it.id == registry_id.into()) {
reg.register_primitive_alternative(op_id, func).with_context(|| {
anyhow!("Impossible to add new primitive alternative for op {}", op_id)
})?;
}
Ok(self)
}

pub fn enable_tract_resource(&mut self) {
self.registries.push(crate::ops::tract_core());
}
Expand Down
58 changes: 10 additions & 48 deletions nnef/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct Registry {
pub docstrings: Option<Vec<String>>,
pub aliases: Vec<Identifier>,
pub fragments: HashMap<Identifier, FragmentDef>,
pub primitives: HashMap<Identifier, Vec<PrimitiveDecl>>,
pub primitives: HashMap<Identifier, PrimitiveDecl>,
pub unit_element_wise_ops: Vec<(Identifier, Box<dyn ElementWiseMiniOp>)>,
pub element_wise_ops: Vec<(Identifier, TypeId, FromTract, Vec<ast::Parameter>, ToTract)>,
pub binary_ops: Vec<BinOp>,
Expand Down Expand Up @@ -89,36 +89,12 @@ impl Registry {
results: results.iter().cloned().map(|it| it.into()).collect(),
};
let primitive_decl = PrimitiveDecl { decl, docstrings: None, to_tract: func };
self.primitives.insert(id.clone(), vec![primitive_decl]);
self.primitives.insert(id.clone(), primitive_decl);
self.primitives
.get_mut(&id)
.and_then(|it| it.last_mut())
.expect("Unexpected empty entry in primitives hashmap")
}

pub fn register_primitive_alternative(
&mut self,
id: impl AsRef<str>,
func: ToTract,
) -> TractResult<&mut PrimitiveDecl> {
let id: Identifier = id.as_ref().into();
self.primitives.get_mut(&id).map_or_else(
|| bail!("No primitive with name '{}' in registry: {}", id.as_ref(), self.id.as_ref()),
|it| -> TractResult<()> {
let last = it.last().unwrap_or_else(|| panic!("Unexpected empty primitive declaration for '{}'", id.as_ref()));
let mut new = last.clone();
new.to_tract = func;
it.insert(0, new);
Ok(())
},
)?;
Ok(self
.primitives
.get_mut(&id)
.and_then(|it| it.last_mut())
.expect("Unexpected empty entry in primitives hashmap"))
}

pub fn register_fragment(&mut self, def: FragmentDef) {
self.fragments.insert(def.decl.id.clone(), def);
}
Expand Down Expand Up @@ -189,28 +165,14 @@ impl Registry {
invocation: &ast::Invocation,
dt: &[Option<DatumType>],
) -> TractResult<Option<Value>> {
if let Some(p) = self.primitives.get(&invocation.id) {
let out_value = p
.iter()
.enumerate()
.find_map(|(idx, op)| {
let resolved = ResolvedInvocation {
invocation,
default_params: &op.decl.parameters,
dt_from_quant_file: dt,
};
(op.to_tract)(builder, &resolved)
.map_err(|err| {
log::debug!(
"Failed to load {:?} with deserializer {}: {:?}",
&invocation.id,
idx,
&err
);
})
.ok()
})
.ok_or(anyhow!("No valid deserializer found for {:?}", &invocation.id))?;
if let Some(op) = self.primitives.get(&invocation.id) {
let resolved = ResolvedInvocation {
invocation,
default_params: &op.decl.parameters,
dt_from_quant_file: dt,
};
let out_value = (op.to_tract)(builder, &resolved)
.with_context(|| format!("Deserializing op `{}'", invocation.id.0))?;
return Ok(Some(out_value));
}
if let Some(ew) = self.unit_element_wise_ops.iter().find(|ew| ew.0 == invocation.id) {
Expand Down

0 comments on commit db4f3a9

Please sign in to comment.