-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add GraphLoader and allow customizing input shapes.
- Loading branch information
1 parent
b2bc825
commit c8080ad
Showing
4 changed files
with
185 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
use std::borrow::Cow; | ||
use std::collections::HashMap; | ||
use std::ops::Deref; | ||
use std::path::Path; | ||
|
||
use crate::graph::Graph; | ||
use crate::onnx::{InputShaper, OnnxDimValue}; | ||
use crate::onnx::external_data::{ExternalDataLoader, NoExternalData, PathExternalData}; | ||
use crate::onnx::load::graph_from_onnx_bytes; | ||
use crate::onnx::result::{OnnxError, OnnxResult, ToOnnxLoadResult}; | ||
use crate::shape::{Shape, Size}; | ||
|
||
/// Load an [ONNX](https://github.com/onnx/onnx/blob/main/docs/IR.md) graph. | ||
/// | ||
/// Many loading settings are customizable: | ||
/// * the source, either from a path through [from_path] or from bytes through [from_bytes]. | ||
/// * whether [external data](https://github.com/onnx/onnx/blob/main/docs/ExternalData.md) is allowed, | ||
/// through [from_path] `allow_external` or [set_external_data]. | ||
/// * input shape overrides (in order of priority): | ||
/// * fully custom through [set_input_shaper_custom] | ||
/// * specific input overrides through [input_shape_overrides] | ||
/// * named axes through [add_named_axis] | ||
/// | ||
/// A simple example: | ||
/// ```no_run | ||
/// # use kn_graph::graph::Graph; | ||
/// # use kn_graph::onnx::GraphLoader; | ||
/// # use kn_graph::shape; | ||
/// # use kn_graph::shape::Size; | ||
/// // load from a path, disallowing external data | ||
/// let mut loader = GraphLoader::from_path("model.onnx", false).unwrap(); | ||
/// // set some named axes | ||
/// loader.add_named_axis("batch_size", Size::BATCH); | ||
/// loader.add_named_axis("sequence_length", Size::fixed(128)); | ||
/// // override the third input shape | ||
/// loader.force_input_shapes(vec![None, None, Some(shape![1, Size::BATCH, 3])]); | ||
/// // load the graph | ||
/// let graph = loader.load().unwrap(); | ||
/// ``` | ||
#[allow(missing_debug_implementations)] | ||
pub struct GraphLoader<'a> { | ||
bytes: Cow<'a, [u8]>, | ||
external: Box<dyn ExternalDataLoader>, | ||
|
||
// input shape overrides | ||
input_shaper_custom: Option<Box<InputShaper>>, | ||
input_shape_overrides: Option<Vec<Option<Shape>>>, | ||
named_axes: HashMap<String, Size>, | ||
} | ||
|
||
impl<'a> GraphLoader<'a> { | ||
pub fn from_path(path: impl AsRef<Path>, allow_external: bool) -> OnnxResult<Self> { | ||
let path = path.as_ref(); | ||
let bytes = std::fs::read(path).to_onnx_result(path)?; | ||
|
||
let external: Box<dyn ExternalDataLoader> = if allow_external { | ||
let parent = path | ||
.parent() | ||
.ok_or_else(|| OnnxError::MustHaveParentPath(path.to_owned()))?; | ||
Box::new(PathExternalData(parent.to_owned())) | ||
} else { | ||
Box::new(NoExternalData) | ||
}; | ||
|
||
Ok(GraphLoader { | ||
bytes: Cow::Owned(bytes), | ||
external, | ||
|
||
input_shaper_custom: None, | ||
input_shape_overrides: None, | ||
named_axes: HashMap::new(), | ||
}) | ||
} | ||
|
||
pub fn from_bytes(bytes: &'a [u8]) -> Self { | ||
GraphLoader { | ||
bytes: Cow::Borrowed(bytes), | ||
external: Box::new(NoExternalData), | ||
|
||
input_shaper_custom: None, | ||
input_shape_overrides: None, | ||
named_axes: HashMap::new(), | ||
} | ||
} | ||
|
||
pub fn set_external_data(&mut self, external: Box<dyn ExternalDataLoader>) { | ||
self.external = external; | ||
} | ||
|
||
pub fn set_input_shaper_custom(&mut self, shaper: Box<InputShaper>) { | ||
self.input_shaper_custom = Some(shaper); | ||
} | ||
|
||
pub fn force_input_shapes(&mut self, shapes: Vec<Option<Shape>>) { | ||
self.input_shape_overrides = Some(shapes) | ||
} | ||
|
||
pub fn add_named_axis(&mut self, name: &str, value: Size) { | ||
self.named_axes.insert(name.to_owned(), value); | ||
} | ||
|
||
pub fn load(self) -> OnnxResult<Graph> { | ||
let mut external = self.external; | ||
|
||
let input_shaper = move |dims: &[OnnxDimValue], name: &str, index| { | ||
// first try custom shaper | ||
if let Some(input_shaper_custom) = &self.input_shaper_custom { | ||
return input_shaper_custom(dims, name, index); | ||
} | ||
// then shape overrides | ||
if let Some(input_shape_overrides) = &self.input_shape_overrides { | ||
if index < input_shape_overrides.len() { | ||
if let Some(shape) = &input_shape_overrides[index] { | ||
return Some(shape.clone()); | ||
} | ||
} else { | ||
return None; | ||
} | ||
} | ||
// finally try basic resolution using named axes | ||
let mut new_dims = vec![]; | ||
for d in dims { | ||
let d_new = match *d { | ||
OnnxDimValue::Value(value) => Size::fixed(value as usize), | ||
OnnxDimValue::Param(ref param) => self.named_axes.get(param)?.clone(), | ||
}; | ||
new_dims.push(d_new); | ||
} | ||
Some(Shape::new(new_dims)) | ||
}; | ||
|
||
graph_from_onnx_bytes(self.bytes.deref(), external.as_mut(), &input_shaper) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters