Skip to content

Commit

Permalink
Add GraphLoader and allow customizing input shapes.
Browse files Browse the repository at this point in the history
  • Loading branch information
KarelPeeters committed Feb 20, 2024
1 parent b2bc825 commit c8080ad
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 42 deletions.
52 changes: 29 additions & 23 deletions kn-graph/src/onnx/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ use crate::graph::{
pub use crate::graph::Graph;
use crate::onnx::external_data::ExternalDataLoader;
use crate::onnx::inputs::{Attributes, Inputs};
use crate::onnx::proto::{ModelProto, tensor_shape_proto, TensorProto, TypeProto};
use crate::onnx::proto::{ModelProto, TensorProto, TypeProto};
use crate::onnx::proto::tensor_proto::DataLocation;
use crate::onnx::proto::tensor_proto::DataType;
use crate::onnx::proto::tensor_shape_proto::dimension::Value as ProtoDimValue;
use crate::onnx::proto::tensor_shape_proto::dimension;
use crate::onnx::proto::type_proto::Value as ProtoTypeValue;
use crate::onnx::result::{Node, OnnxError, OnnxResult, UnwrapProto};
use crate::onnx::store::Store;
Expand All @@ -31,8 +31,16 @@ use crate::shape::{Shape, Size};
// things to grep for: unwrap|expect|assert|panic
// introduce two main error kinds: "bug in file" and "unsupported"

/// Callback that decides the final [`Shape`] of a single graph input.
///
/// The arguments are, in order:
/// * the dimensions of the input as declared in the ONNX file,
/// * the name of the input,
/// * the index of the input, counting only inputs that are not already defined as initializers.
///
/// Returning `None` makes loading fail with `OnnxError::FailedToShapeInput`.
pub type InputShaper = dyn Fn(&[OnnxDimValue], &str, usize) -> Option<Shape>;

/// A single dimension of an input shape, as declared in the ONNX file.
///
/// Values of this type are handed to the input shaper callback, which turns them into a real shape.
/// The type is comparable and hashable so callbacks can match on or look up dimensions directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OnnxDimValue {
    /// A concrete dimension size (the proto `dim_value` case).
    ///
    /// This is the raw signed value read from the file, it is not checked to be non-negative.
    Value(i64),
    /// A symbolic, named dimension (the proto `dim_param` case), for example `"batch_size"`.
    Param(String),
}

// we use &dyn to avoid duplicate codegen of this large and non-critical function
pub fn graph_from_onnx_bytes(buf: &[u8], external: &mut dyn ExternalDataLoader) -> OnnxResult<Graph> {
pub fn graph_from_onnx_bytes(buf: &[u8], external: &mut dyn ExternalDataLoader, input_shaper: &InputShaper) -> OnnxResult<Graph> {
let model = load_model_proto(buf);
let model_graph = model.graph.as_ref().unwrap_proto("model.graph")?;

Expand All @@ -45,20 +53,26 @@ pub fn graph_from_onnx_bytes(buf: &[u8], external: &mut dyn ExternalDataLoader)
nodes.define(&tensor.name, OnnxValue::Value(value))
}

// clear newly defined values so we don't attribute them to the first node
graph.take_new_values();

// load inputs
let mut real_input_index = 0;
for input in &model_graph.input {
// initializers are allowed to re-appear in the inputs, so we skip them the second time
if nodes.contains(&input.name) {
continue;
}

let (shape, dtype) = resolve_tensor_type(input.r#type.as_ref().unwrap_proto("input.type")?, &input.name)?;
let input_proto = input.r#type.as_ref().unwrap_proto("input.type")?;
let (shape, dtype) = resolve_tensor_type(input_proto, &input.name, real_input_index, input_shaper)?;
let value = graph.input(shape, dtype);
nodes.define(&input.name, OnnxValue::Value(value));

real_input_index += 1;
}

// clear newly defined values so we don't attribute them to the first node
let _ = graph.take_new_values();

// load nodes
for node_proto in &model_graph.node {
let node = Node {
name: node_proto.name.as_str(),
Expand Down Expand Up @@ -1083,24 +1097,28 @@ fn define_tensor_data(
Ok(value)
}

fn resolve_tensor_type(ty: &TypeProto, name: &str) -> OnnxResult<(Shape, DType)> {
fn resolve_tensor_type(ty: &TypeProto, name: &str, index: usize, input_shaper: &InputShaper) -> OnnxResult<(Shape, DType)> {
let value = ty.value.as_ref().expect("Value doesn't have type set");
let result = match value {
ProtoTypeValue::TensorType(tensor) => {
let data_type = DataType::try_from(tensor.elem_type).expect("Invalid data type");
let dtype = resolve_dtype(data_type, name)?;

let dims = tensor
.shape
.as_ref()
.expect("Tensor does not have shape set")
.dim
.iter()
.map(resolve_tensor_dim)
.map(|d| match *d.value.as_ref().expect("Missing value for dimension") {
dimension::Value::DimValue(value) => OnnxDimValue::Value(value),
dimension::Value::DimParam(ref param) => OnnxDimValue::Param(param.clone()),
})
.collect_vec();

let dtype = resolve_dtype(data_type, name)?;
let shape = input_shaper(&dims, name, index).ok_or_else(|| OnnxError::FailedToShapeInput(dims, name.to_owned(), index))?;

(Shape::new(dims), dtype)
(shape, dtype)
}
_ => panic!("Unsupported value kind {:?}", value),
};
Expand Down Expand Up @@ -1134,18 +1152,6 @@ fn resolve_dtype(data_type: DataType, node: &str) -> OnnxResult<DType> {
Ok(dtype)
}

fn resolve_tensor_dim(dim: &tensor_shape_proto::Dimension) -> Size {
let value = dim.value.as_ref().expect("Missing value for dimension");

match value {
&ProtoDimValue::DimValue(inner) => Size::fixed(inner as usize),
ProtoDimValue::DimParam(name) => {
assert_eq!(name, "batch_size");
Size::BATCH
}
}
}

fn abs_axis(axis: i64, rank: usize) -> usize {
if axis == i64::MAX {
rank
Expand Down
134 changes: 134 additions & 0 deletions kn-graph/src/onnx/loader.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::ops::Deref;
use std::path::Path;

use crate::graph::Graph;
use crate::onnx::{InputShaper, OnnxDimValue};
use crate::onnx::external_data::{ExternalDataLoader, NoExternalData, PathExternalData};
use crate::onnx::load::graph_from_onnx_bytes;
use crate::onnx::result::{OnnxError, OnnxResult, ToOnnxLoadResult};
use crate::shape::{Shape, Size};

/// Load an [ONNX](https://github.com/onnx/onnx/blob/main/docs/IR.md) graph.
///
/// Many loading settings are customizable:
/// * the source, either from a path through [`Self::from_path`] or from bytes through [`Self::from_bytes`].
/// * whether [external data](https://github.com/onnx/onnx/blob/main/docs/ExternalData.md) is allowed,
///   through the `allow_external` argument of [`Self::from_path`] or through [`Self::set_external_data`].
/// * input shape overrides (in order of priority):
///     * fully custom through [`Self::set_input_shaper_custom`]
///     * specific input overrides through [`Self::force_input_shapes`]
///     * named axes through [`Self::add_named_axis`]
///
/// The priority rules are strict:
/// * if a custom shaper is set, its answer is final for every input and nothing else is consulted.
/// * if forced input shapes are set, an input whose index falls outside of the list fails to load,
///   while a `None` entry falls through to the named axes.
/// * with named axes, every symbolic dimension of the input must have been registered,
///   otherwise the input fails to load.
///
/// A simple example:
/// ```no_run
/// # use kn_graph::graph::Graph;
/// # use kn_graph::onnx::GraphLoader;
/// # use kn_graph::shape;
/// # use kn_graph::shape::Size;
/// // load from a path, disallowing external data
/// let mut loader = GraphLoader::from_path("model.onnx", false).unwrap();
/// // set some named axes
/// loader.add_named_axis("batch_size", Size::BATCH);
/// loader.add_named_axis("sequence_length", Size::fixed(128));
/// // override the third input shape
/// loader.force_input_shapes(vec![None, None, Some(shape![1, Size::BATCH, 3])]);
/// // load the graph
/// let graph = loader.load().unwrap();
/// ```
#[allow(missing_debug_implementations)]
pub struct GraphLoader<'a> {
// raw ONNX file contents: owned when read from a path, borrowed when constructed from bytes
bytes: Cow<'a, [u8]>,
// loader handed to the ONNX parser for external tensor data
external: Box<dyn ExternalDataLoader>,

// input shape overrides, listed from highest to lowest priority
// custom callback that, when set, decides every input shape on its own
input_shaper_custom: Option<Box<InputShaper>>,
// per-input forced shapes, indexed by input index; `None` entries are not overridden
input_shape_overrides: Option<Vec<Option<Shape>>>,
// sizes to substitute for symbolic dimension names
named_axes: HashMap<String, Size>,
}

impl<'a> GraphLoader<'a> {
pub fn from_path(path: impl AsRef<Path>, allow_external: bool) -> OnnxResult<Self> {
let path = path.as_ref();
let bytes = std::fs::read(path).to_onnx_result(path)?;

let external: Box<dyn ExternalDataLoader> = if allow_external {
let parent = path
.parent()
.ok_or_else(|| OnnxError::MustHaveParentPath(path.to_owned()))?;
Box::new(PathExternalData(parent.to_owned()))
} else {
Box::new(NoExternalData)
};

Ok(GraphLoader {
bytes: Cow::Owned(bytes),
external,

input_shaper_custom: None,
input_shape_overrides: None,
named_axes: HashMap::new(),
})
}

pub fn from_bytes(bytes: &'a [u8]) -> Self {
GraphLoader {
bytes: Cow::Borrowed(bytes),
external: Box::new(NoExternalData),

input_shaper_custom: None,
input_shape_overrides: None,
named_axes: HashMap::new(),
}
}

pub fn set_external_data(&mut self, external: Box<dyn ExternalDataLoader>) {
self.external = external;
}

pub fn set_input_shaper_custom(&mut self, shaper: Box<InputShaper>) {
self.input_shaper_custom = Some(shaper);
}

pub fn force_input_shapes(&mut self, shapes: Vec<Option<Shape>>) {
self.input_shape_overrides = Some(shapes)
}

pub fn add_named_axis(&mut self, name: &str, value: Size) {
self.named_axes.insert(name.to_owned(), value);
}

pub fn load(self) -> OnnxResult<Graph> {
let mut external = self.external;

let input_shaper = move |dims: &[OnnxDimValue], name: &str, index| {
// first try custom shaper
if let Some(input_shaper_custom) = &self.input_shaper_custom {
return input_shaper_custom(dims, name, index);
}
// then shape overrides
if let Some(input_shape_overrides) = &self.input_shape_overrides {
if index < input_shape_overrides.len() {
if let Some(shape) = &input_shape_overrides[index] {
return Some(shape.clone());
}
} else {
return None;
}
}
// finally try basic resolution using named axes
let mut new_dims = vec![];
for d in dims {
let d_new = match *d {
OnnxDimValue::Value(value) => Size::fixed(value as usize),
OnnxDimValue::Param(ref param) => self.named_axes.get(param)?.clone(),
};
new_dims.push(d_new);
}
Some(Shape::new(new_dims))
};

graph_from_onnx_bytes(self.bytes.deref(), external.as_mut(), &input_shaper)
}
}
39 changes: 20 additions & 19 deletions kn-graph/src/onnx/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::path::Path;

use external_data::ExternalDataLoader;
pub use load::{InputShaper, OnnxDimValue};
pub use loader::GraphLoader;

use crate::graph::Graph;
use crate::onnx::external_data::{NoExternalData, PathExternalData};
use crate::onnx::load::graph_from_onnx_bytes;
use crate::onnx::result::{OnnxError, OnnxResult, ToOnnxLoadResult};
use crate::onnx::result::OnnxResult;
use crate::shape::Size;

pub mod external_data;
mod inputs;
Expand All @@ -15,35 +16,35 @@ mod proto;
pub mod result;
mod store;
mod typed_value;
mod loader;

/// Load an [ONNX](https://github.com/onnx/onnx/blob/main/docs/IR.md) file from the given path.
///
/// If `allow_external` is true, the onnx will be allowed to load external data files,
/// see [the spec](https://github.com/onnx/onnx/blob/main/docs/IR.md#external-tensor-data).
/// If `allow_external` is false and the ONNX file does reference external data, an error is returned.
///
/// For more flexibility, see [GraphLoader].
pub fn load_graph_from_onnx_path(path: impl AsRef<Path>, allow_external: bool) -> OnnxResult<Graph> {
let path = path.as_ref();
let buf = std::fs::read(path).to_onnx_result(path)?;

let mut external: Box<dyn ExternalDataLoader> = if allow_external {
let parent = path
.parent()
.ok_or_else(|| OnnxError::MustHaveParentPath(path.to_owned()))?;
Box::new(PathExternalData(parent.to_owned()))
} else {
Box::new(NoExternalData)
};

graph_from_onnx_bytes(&buf, &mut *external)
let mut loader = GraphLoader::from_path(path, allow_external)?;
loader.add_named_axis("batch_size", Size::BATCH);
loader.load()
}

/// Load an [ONNX](https://github.com/onnx/onnx/blob/main/docs/IR.md) file from the given bytes.
///
/// The file is not allowed to reference external data files.
///
/// For more flexibility, see [GraphLoader].
pub fn load_graph_from_onnx_bytes(buffer: &[u8]) -> OnnxResult<Graph> {
graph_from_onnx_bytes(buffer, &mut NoExternalData)
let mut loader = GraphLoader::from_bytes(buffer);
loader.add_named_axis("batch_size", Size::BATCH);
loader.load()
}

pub fn load_graph_from_onnx_bytes_custom(buffer: &[u8], external: &mut dyn ExternalDataLoader) -> OnnxResult<Graph> {
graph_from_onnx_bytes(buffer, external)
pub fn load_graph_from_onnx_bytes_custom(buffer: &[u8], external: Box<dyn ExternalDataLoader>) -> OnnxResult<Graph> {
let mut loader = GraphLoader::from_bytes(buffer);
loader.set_external_data(external);
loader.add_named_axis("batch_size", Size::BATCH);
loader.load()
}
2 changes: 2 additions & 0 deletions kn-graph/src/onnx/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt::{Display, Formatter};
use std::io;
use std::path::{Path, PathBuf};

use crate::onnx::load::OnnxDimValue;
use crate::onnx::proto::attribute_proto::AttributeType;
use crate::onnx::proto::tensor_proto::DataType;
use crate::onnx::typed_value::AsShapeError;
Expand All @@ -21,6 +22,7 @@ pub enum OnnxError {

NonNormalExternalDataPath(PathBuf),
MustHaveParentPath(PathBuf),
FailedToShapeInput(Vec<OnnxDimValue>, String, usize),

MissingProtoField(&'static str),

Expand Down

0 comments on commit c8080ad

Please sign in to comment.