Question about Pulse #509
-
That's exactly what tract-pulse is about. It works for convolutional and recurrent nets, but the API is a bit... well, convoluted, so I have not documented it: I consider the API semi-experimental (the code does work well, it's critical for our usage). You can "infer" a bit of the process by looking at … The gist of it:
The command line can be helpful too (try …). Try it, show me where you get stuck. I will help :)
-
Thanks for your explanation and pointers! I will probably need some time to get things rolling, but will come back eventually.
-
Another question about the delay:
-
Tell me if this helps. 1D convolution, kernel of 3, valid:
123 represents the application of the convolution kernel. tract delay: 2, the first two frames output have undefined values.
1D convolution, kernel of 3, left padded to same size:
Here, 0 is the padding value, and the first actual sample fed to tract is still 1. tract reports a delay of 0.
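To make the two cases concrete, here is a minimal plain-Rust sketch (not tract's API or implementation; all names are hypothetical). A streaming kernel-3 convolution keeps its last three inputs in a buffer, and `None` stands in for the undefined values mentioned above. With a VALID convolution the buffer starts empty, so the first two outputs are undefined (delay 2); with left padding the buffer starts full of zeros, so every output is defined (delay 0).

```rust
/// Hypothetical streaming 1-D convolution (kernel size K), one frame per call.
/// `None` models a buffer slot that has not received a real frame yet.
struct StreamingConv1d {
    kernel: Vec<f32>,          // kernel[0] multiplies the oldest buffered frame
    buffer: Vec<Option<f32>>,  // the last K inputs, oldest first
}

impl StreamingConv1d {
    /// VALID convolution: the buffer starts empty, so the first K-1 outputs are undefined.
    fn valid(kernel: Vec<f32>) -> Self {
        let k = kernel.len();
        Self { kernel, buffer: vec![None; k] }
    }

    /// Left-padded ("same") convolution: the buffer starts full of zero padding.
    fn left_padded(kernel: Vec<f32>) -> Self {
        let k = kernel.len();
        Self { kernel, buffer: vec![Some(0.0); k] }
    }

    /// Push one input frame and get one output frame (None if it depends on missing data).
    fn push(&mut self, x: f32) -> Option<f32> {
        self.buffer.remove(0);
        self.buffer.push(Some(x));
        let mut acc = 0.0;
        for (w, v) in self.kernel.iter().zip(&self.buffer) {
            acc += w * (*v)?; // bail out with None on any undefined slot
        }
        Some(acc)
    }
}

/// Number of leading undefined output frames, i.e. what tract reports as the delay.
fn leading_undefined(outputs: &[Option<f32>]) -> usize {
    outputs.iter().take_while(|o| o.is_none()).count()
}
```

Feeding 1, 2, 3, 4, 5 with an all-ones kernel, the VALID version yields two undefined frames followed by 6.0 (1 + 2 + 3), while the left-padded version starts directly with 1.0 (0 + 0 + 1).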
-
Yes, I mean the second example. Can the left padding to same size be specified per convolution, or for the entire model or input?
-
Well, you have both options. ONNX will let you provide explicit padding for each convolution and put it all on the left (aka "before"), mimicking the second case. Alternatively, you could declare the convolutions VALID in the ONNX model and feed the network enough zeros to flush the undefined values out of the buffers before starting to push the actual signal. So that would be:
tract would still report a delay of 2: you need to feed two rounds of zeros and skip their matching outputs. Then you can feed your signal and get the expected result.
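A plain-Rust sketch of the flushing approach (hypothetical names, not tract's API): a VALID streaming convolution with kernel size K is first fed K - 1 zeros whose outputs are dropped; the outputs for the real signal then match an offline convolution over the signal left-padded with K - 1 zeros.

```rust
/// Hypothetical VALID streaming convolution: the last K inputs are buffered,
/// and `None` marks slots that have never been written.
struct ValidConv {
    kernel: Vec<f32>,
    buffer: Vec<Option<f32>>,
}

impl ValidConv {
    fn new(kernel: Vec<f32>) -> Self {
        let k = kernel.len();
        Self { kernel, buffer: vec![None; k] }
    }

    /// Push one frame; the output is None while the buffer still holds undefined slots.
    fn push(&mut self, x: f32) -> Option<f32> {
        self.buffer.remove(0);
        self.buffer.push(Some(x));
        let mut acc = 0.0;
        for (w, v) in self.kernel.iter().zip(&self.buffer) {
            acc += w * (*v)?;
        }
        Some(acc)
    }
}

/// Stream a VALID conv as if it were left-padded: feed `delay` zeros first and
/// discard their outputs, then stream the real signal.
fn run_flushed(kernel: Vec<f32>, signal: &[f32]) -> Vec<Option<f32>> {
    let delay = kernel.len() - 1; // delay of a VALID conv with kernel size K
    let mut conv = ValidConv::new(kernel);
    for _ in 0..delay {
        let _skipped = conv.push(0.0); // outputs matching the flush zeros are dropped
    }
    signal.iter().map(|&x| conv.push(x)).collect()
}

/// Offline reference: convolution of the signal left-padded with K - 1 zeros.
fn left_padded_reference(kernel: &[f32], signal: &[f32]) -> Vec<f32> {
    let k = kernel.len();
    let mut padded = vec![0.0f32; k - 1];
    padded.extend_from_slice(signal);
    padded
        .windows(k)
        .map(|w| w.iter().zip(kernel).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}
```

Every streamed output after the flush is defined, and there is exactly one per input frame of the real signal.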
-
I'm having some trouble with the streaming inputs.
```python
import onnx

m = onnx.load("enc.onnx")
print(m.graph.input[0])
print(m.graph.input[1])
```
I tried the following so far:
```rust
use tract_onnx::{
    prelude::{tvec, Datum, Framework, InferenceFact, InferenceModelExt, TractResult},
    tract_hir::shapefactoid,
};
use tract_pulse::internal::ToDim;
use tract_pulse::model::*;

fn main() -> TractResult<()> {
    let s = tract_pulse::fact::stream_dim();
    let erb_feat = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(1, 1, s, 16));
    let hemb = InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 1, 256));
    let mut enc = tract_onnx::onnx()
        .model_for_path("../out/enc.onnx")?
        .with_input_fact(0, erb_feat)?
        .with_input_fact(1, hemb)?
        .with_input_names(&["erb_feat", "h0emb"])?;
    enc.analyse(true)?;
    let enc = enc.into_typed()?;
    let enc = enc.declutter()?;
    let pulsed = PulsedModel::new(&enc, 1)?;
    Ok(())
}
```
Resulting in the following error:
```
Error: Translating node #13 "h0emb" Source Pulsifier(1)

Caused by:
    Can not pulse a tensor with no streaming dim

Stack backtrace:
   0: anyhow::private::new_adhoc
             at /home/hendrik/.cargo/registry/src/github.com-1ecc6299db9ec823/anyhow-1.0.41/src/lib.rs:632:36
   1: tract_pulse::fact::PulsedFact::from_tensor_fact_pulse::{{closure}}
             at /home/hendrik/.cargo/registry/src/github.com-1ecc6299db9ec823/tract-pulse-0.14.2/src/fact.rs:51:28
   2: core::option::Option<T>::ok_or_else
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/core/src/option.rs:595:25
   3: tract_pulse::fact::PulsedFact::from_tensor_fact_pulse
             at /home/hendrik/.cargo/registry/src/github.com-1ecc6299db9ec823/tract-pulse-0.14.2/src/fact.rs:48:27
   4: tract_pulse::ops::source::pulsify
             at /home/hendrik/.cargo/registry/src/github.com-1ecc6299db9ec823/tract-pulse-0.14.2/src/ops/source.rs:14:23
   5: tract_pulse::ops::source::register_all::{{closure}}
             at /home/hendrik/.cargo/registry/src/github.com-1ecc6299db9ec823/tract-pulse-0.14.2/src/macros.rs:35:25
   6: core::ops::function::FnOnce::call_once
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/core/src/ops/function.rs:227:5
   7: <tract_pulse::model::Pulsifier as tract_core::model::translator::Translate<tract_core::model::fact::TypedFact,alloc::boxed::Box<dyn tract_core::ops::TypedOp>,tract_pulse::fact::PulsedFact,alloc::boxed::Box<dyn tract_pulse::ops::PulsedOp>>>::translate_node
             at /home/hendrik/.cargo/registry/src/github.com-1ecc6299db9ec823/tract-pulse-0.14.2/src/model.rs:104:13
   8: tract_core::model::translator::Translate::translate_model_with_mappings
             at /home/hendrik/.cargo/registry/src/github.com-1ecc6299db9ec823/tract-core-0.14.2/src/model/translator.rs:35:27
   9: <tract_core::model::graph::Graph<tract_pulse::fact::PulsedFact,alloc::boxed::Box<dyn tract_pulse::ops::PulsedOp>> as tract_pulse::model::PulsedModelExt>::new_with_mapping
             at /home/hendrik/.cargo/registry/src/github.com-1ecc6299db9ec823/tract-pulse-0.14.2/src/model.rs:28:9
  10: <tract_core::model::graph::Graph<tract_pulse::fact::PulsedFact,alloc::boxed::Box<dyn tract_pulse::ops::PulsedOp>> as tract_pulse::model::PulsedModelExt>::new
             at /home/hendrik/.cargo/registry/src/github.com-1ecc6299db9ec823/tract-pulse-0.14.2/src/model.rs:20:12
  11: clc_tract::main
             at /home/hendrik/projects/clcrs/clc-tract/src/main.rs:20:18
  12: core::ops::function::FnOnce::call_once
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/core/src/ops/function.rs:227:5
  13: std::sys_common::backtrace::__rust_begin_short_backtrace
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/std/src/sys_common/backtrace.rs:125:18
  14: std::rt::lang_start::{{closure}}
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/std/src/rt.rs:49:18
  15: core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/core/src/ops/function.rs:259:13
      std::panicking::try::do_call
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/std/src/panicking.rs:401:40
      std::panicking::try
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/std/src/panicking.rs:365:19
      std::panic::catch_unwind
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/std/src/panic.rs:434:14
      std::rt::lang_start_internal
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/std/src/rt.rs:34:21
  16: std::rt::lang_start
             at /rustc/24bdc6d73a75dce9a7013ebc7c037013ff4ea099/library/std/src/rt.rs:48:5
  17: main
  18: __libc_start_main
  19: _start
```
-
I think it is a bug. Can you give me the model? You can use @.fr if you prefer to share by email.
-
This reproduces the error:
```python
import torch
from torch import nn
import onnx

layer_count = 4

model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
model.eval()

with torch.no_grad():
    input = torch.randn(5, 3, 10)
    h0 = torch.randn(layer_count * 2, 3, 20)
    c0 = torch.randn(layer_count * 2, 3, 20)
    output, (hn, cn) = model(input, (h0, c0))

    # default export
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx')
    onnx_model = onnx.load('lstm.onnx')
    # input shape [5, 3, 10]
    print(onnx_model.graph.input[0])

    # export with `dynamic_axes`
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx',
                      input_names=['input', 'h0', 'c0'],
                      output_names=['output', 'hn', 'cn'],
                      dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
    onnx_model = onnx.load('lstm.onnx')
    # input shape ['sequence', 3, 10]
    print(onnx_model.graph.input[0])
    print(onnx_model.graph.input[1])
    print(onnx_model.graph.input[2])
```
```rust
use tract_onnx::{
    prelude::{tvec, Datum, Framework, InferenceFact, InferenceModelExt, TractResult},
    tract_hir::shapefactoid,
};
use tract_pulse::internal::ToDim;
use tract_pulse::model::*;

fn main() -> TractResult<()> {
    let s = tract_pulse::fact::stream_dim();
    let input = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(s, 3, 10));
    let h0 = InferenceFact::dt_shape(f32::datum_type(), tvec!(4 * 2, 3, 20));
    let c0 = InferenceFact::dt_shape(f32::datum_type(), tvec!(4 * 2, 3, 20));
    let mut enc = tract_onnx::onnx()
        .model_for_path("lstm.onnx")?
        .with_input_fact(0, input)?
        .with_input_fact(1, h0)?
        .with_input_fact(2, c0)?
        .with_input_names(&["input", "h0", "c0"])?;
    enc.analyse(true)?;
    let enc = enc.into_typed()?;
    let enc = enc.declutter()?;
    let pulsed = PulsedModel::new(&enc, 1)?;
    Ok(())
}
```
-
OK, there may be a bug, or at least a not-so-helpful error message. But there will be a deeper issue here: you cannot use pulses with a bidirectional LSTM, as its output depends on inputs that will only be given to the network in subsequent pulses. The issue you observe is still present with simple causal LSTM layers, so I'm going to investigate it.
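A toy illustration of why the backward direction cannot be pulsed. To keep it short it uses a scalar recurrence instead of an actual LSTM (an assumption; the function names are hypothetical): the forward scan's output for a frame never changes once later frames arrive, while the backward scan's output for frame 0 does.

```rust
/// Toy stand-in for the backward half of a bidirectional RNN (not an LSTM):
/// h[t] = x[t] + decay * h[t + 1], computed from the last frame to the first.
fn backward_scan(xs: &[f32], decay: f32) -> Vec<f32> {
    let mut out = vec![0.0f32; xs.len()];
    let mut h = 0.0f32; // state "after" the end of the sequence
    for t in (0..xs.len()).rev() {
        h = xs[t] + decay * h;
        out[t] = h;
    }
    out
}

/// Toy forward (causal) recurrence: h[t] = x[t] + decay * h[t - 1].
fn forward_scan(xs: &[f32], decay: f32) -> Vec<f32> {
    let mut h = 0.0f32;
    xs.iter()
        .map(|&x| {
            h = x + decay * h;
            h
        })
        .collect()
}
```

With inputs 1, 2 and then 3 arriving later, the backward output for frame 0 is 2.0 on the prefix but 2.75 on the full sequence, so it cannot be emitted until the whole sequence is known.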
-
Next issue... The error message is still not great and there is still an opportunity to be more robust, but you can fix it by just giving a value to h0 and c0 (hopefully this makes sense in your case). That way they will get absorbed early by the Scan op that implements LSTMs, and it will deal with the pulsing aspect. If you need to delay providing h0 and c0, then we need some more work here.
But we have one more blocking issue:
I think this is basically just a TODO :) Having a look.
-
I get the same error with unidirectional LSTMs.
-
Yep, I'm trying to fix these.
-
Another point, which might also be related: RNN-type models with an output state (i.e.
Theoretically, the input and output states that do not have a pulse dimension could also be fully handled by tract-pulse. That is, the model would need to initialize the first hidden state itself, and tract-pulse would make sure that the state is passed to the network on every call.
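A minimal sketch of that idea in plain Rust, using a hypothetical one-unit tanh RNN rather than tract's API: the struct owns the hidden state, initializes it itself, and carries it across calls, so the caller only ever passes input chunks. Splitting the signal into pulses then gives the same outputs as processing it in one go.

```rust
/// Hypothetical streaming wrapper: the recurrent state lives inside the struct,
/// is initialized once, and is carried across calls.
struct StatefulRnn {
    w_in: f32,
    w_rec: f32,
    hidden: f32,
}

impl StatefulRnn {
    fn new(w_in: f32, w_rec: f32) -> Self {
        // The model, not the caller, chooses the initial hidden state (zero here).
        Self { w_in, w_rec, hidden: 0.0 }
    }

    /// Process one pulse of frames and return one output per frame.
    fn run_pulse(&mut self, pulse: &[f32]) -> Vec<f32> {
        pulse
            .iter()
            .map(|&x| {
                self.hidden = (self.w_in * x + self.w_rec * self.hidden).tanh();
                self.hidden
            })
            .collect()
    }
}
```

Because the state never leaves the struct, there is no h0/c0 input or hn/cn output in the interface, which is the shape of network that pulsing needs.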
-
Here we go. I knew this was going to be terrible, as this API and its quirks are obviously not ready for general consumption... You may need 0.15.0.
```rust
use tract_onnx::prelude::*;
use tract_onnx::tract_hir::shapefactoid;
use tract_pulse::internal::ToDim;
use tract_pulse::model::*;

fn constantize_input(
    model: &mut InferenceModel,
    name: &str,
    value: Arc<Tensor>,
) -> TractResult<()> {
    let node_id = model.node_by_name(name)?.id;
    model.node_mut(node_id).op =
        tract_onnx::tract_core::ops::konst::Const::new(value.clone()).into();
    model.node_mut(node_id).outputs[0].fact = InferenceFact::from(value);
    Ok(())
}

fn main() -> TractResult<()> {
    let s = tract_pulse::fact::stream_dim();
    let input = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(s, 3, 10));
    let mut enc = tract_onnx::onnx().model_for_path("lstm.onnx")?.with_input_fact(0, input)?;

    // h0 and c0 must be erased from the network interface so that pulse and scan can do their thing:
    // first, the state inputs are turned into constants
    let h0 = Tensor::zero::<f32>(&[4 * 2, 3, 20])?;
    let c0 = Tensor::zero::<f32>(&[4 * 2, 3, 20])?;
    constantize_input(&mut enc, "h0", h0.into())?;
    constantize_input(&mut enc, "c0", c0.into())?;

    // then the state vector outputs are removed from the interface, as tract will deal with them
    enc = enc.with_input_names(&["input"])?.with_output_names(&["Squeeze_15"])?;

    // bring the network to typed and decluttered form
    enc.analyse(true)?;
    let enc = enc.into_typed()?;
    let enc = enc.declutter()?;

    // transform to pulsed form
    let pulsed = PulsedModel::new(&enc, 1)?;
    let delay = pulsed.output_fact(0)?.delay;
    assert_eq!(delay, 0); // no convolution in the net, and LSTMs do not introduce delay

    // then back to typed, then optimized and runnable form
    let model = pulsed.into_typed()?.into_optimized()?.into_runnable()?;

    // we cannot use model.run(), we have to do a bit of explicit state management here
    let mut state = SimpleState::new(model)?;

    // loop over the input stream
    for v in 0..4 {
        // mocked input chunks
        let input = tensor0(v as f32).broadcast_scalar_to_shape(&[1, 3, 10])?;
        let output = state.run(tvec!(input.into()))?;
        eprintln!("{:?}", output);
    }
    Ok(())
}
```
-
So... yeah, basically tract already does what you were suggesting, but the ONNX export has its own state management that gets in the way, so we need to cut off a part of the network (the state-maintenance inputs and outputs) to make the pulsification happy.
-
Awesome, thanks for your help! Maybe it makes sense to include this in the examples folder?
-
I don't think it's ready to become an example. How about turning this issue into a conversation instead of closing it, so people can use it as a reference?
-
Another question (or possible bug) about the delay: given the following ConvLstm with 3x1 time-aligned kernels (3 is the time axis, 1 is the feature axis), achieved via zero-padding before the signal:
The tract-pulse delay does not seem to take the first padding into account and always reports a delay corresponding to the kernel size of
To reproduce:
```python
import torch
from torch import nn
import onnx

bs = 1
layer_count = 4
lstm_in = 80
lstm_h = 20


def convt1(in_ch, out_ch, t=1, lookahead=0):
    """Time aligned conv with kernel size (t x 1)"""
    return nn.Sequential(
        nn.ConstantPad2d((0, 0, t - 1 - lookahead, lookahead), 0.0),
        nn.Conv2d(in_ch, out_ch, kernel_size=(t, 1)),
    )


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = convt1(1, 4, t=3, lookahead=0)  # Changing lookahead here does not change the reported tract-pulse delay
        self.conv1 = convt1(4, 8, t=3, lookahead=0)  # Changing lookahead here results in correct tract-pulse delay calculation
        self.lstm = nn.LSTM(lstm_in, lstm_h, num_layers=layer_count, batch_first=True)

    def forward(self, x, h):
        # input shape: [B, C, T, F]
        x = self.conv0(x)  # [B, 4, T, 10]
        x = self.conv1(x)  # [B, 8, T, 10]
        x = x.transpose(1, 2).reshape(bs, -1, lstm_in)  # [B, T, 8*10]
        return self.lstm(x, h)


model = Model()
model.eval()
print(model)
with torch.no_grad():
    input = torch.randn(1, 1, 5, 10)
    h0 = torch.randn(layer_count, 1, lstm_h)
    c0 = torch.randn(layer_count, 1, lstm_h)
    output, (hn, cn) = model(input, (h0, c0))
    # default export
    torch.onnx.export(model, (input, (h0, c0)), "lstm.onnx")
    onnx_model = onnx.load("lstm.onnx")
    # export with `dynamic_axes`: time is axis 2 of the [B, C, T, F] input
    # and axis 1 of the batch_first [B, T, H] output
    torch.onnx.export(
        model,
        (input, (h0, c0)),
        "lstm.onnx",
        input_names=["input", "h0", "c0"],
        output_names=["output", "hn", "cn"],
        dynamic_axes={"input": {2: "sequence"}, "output": {1: "sequence"}},
    )
    onnx_model = onnx.load("lstm.onnx")
```
```rust
use tract_onnx::prelude::*;
use tract_onnx::tract_hir::shapefactoid;
use tract_pulse::internal::ToDim;
use tract_pulse::model::*;

fn constantize_input(
    model: &mut InferenceModel,
    name: &str,
    value: Arc<Tensor>,
) -> TractResult<()> {
    let node_id = model.node_by_name(name)?.id;
    model.node_mut(node_id).op =
        tract_onnx::tract_core::ops::konst::Const::new(value.clone()).into();
    model.node_mut(node_id).outputs[0].fact = InferenceFact::from(value);
    Ok(())
}

fn main() -> TractResult<()> {
    let s = tract_pulse::fact::stream_dim();
    let input = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(1, 1, s, 10));
    let mut enc = tract_onnx::onnx().model_for_path("lstm.onnx")?.with_input_fact(0, input)?;

    // h0 and c0 must be erased from the network interface so that pulse and scan can do their thing:
    // first, the state inputs are turned into constants
    let h0 = Tensor::zero::<f32>(&[4, 1, 20])?;
    let c0 = Tensor::zero::<f32>(&[4, 1, 20])?;
    constantize_input(&mut enc, "h0", h0.into())?;
    constantize_input(&mut enc, "c0", c0.into())?;

    // then the state vector outputs are removed from the interface, as tract will deal with them
    enc = enc.with_input_names(&["input"])?.with_output_names(&["Squeeze_15"])?;

    // bring the network to typed and decluttered form
    enc.analyse(true)?;
    let enc = enc.into_typed()?;
    let enc = enc.declutter()?;

    // transform to pulsed form
    let pulsed = PulsedModel::new(&enc, 1)?;
    let delay = pulsed.output_fact(0)?.delay;
    dbg!(delay);
    //assert_eq!(delay, 0); // TODO: convs should be time aligned, and LSTMs do not introduce delay

    // then back to typed, then optimized and runnable form
    let model = pulsed.into_typed()?.into_optimized()?.into_runnable()?;

    // we cannot use model.run(), we have to do a bit of explicit state management here
    let mut state = SimpleState::new(model)?;

    // loop over the input stream
    for v in 0..4 {
        // mocked input chunks
        let input = tensor0(v as f32).broadcast_scalar_to_shape(&[1, 1, 1, 10])?;
        let output = state.run(tvec!(input.into()))?;
        eprintln!("{:?}", output);
    }
    Ok(())
}
```
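For reference, the delay one would expect here can be worked out from the rule stated earlier in this thread (left padding adds no pulse delay, while a frame that needs future inputs has to wait for them). With convt1, output frame i reads input frames i - (t - 1 - lookahead) through i + lookahead, so each layer should add lookahead frames of delay and a stack adds them up. The sketch below only encodes that expected arithmetic; TimeConv and expected_stack_delay are hypothetical names, and this is not how tract computes its delay.

```rust
/// One time-aligned conv as built by `convt1` in the Python repro:
/// kernel of `t` frames, left pad `t - 1 - lookahead`, right pad `lookahead`.
#[derive(Clone, Copy)]
struct TimeConv {
    t: usize,         // kernel size on the time axis
    lookahead: usize, // right padding, i.e. how many future frames are read
}

impl TimeConv {
    /// Left padding applied by ConstantPad2d; it only covers past frames.
    fn left_pad(&self) -> usize {
        assert!(self.lookahead < self.t, "lookahead must be smaller than the kernel");
        self.t - 1 - self.lookahead
    }

    /// Output frame i reads inputs i - left_pad ..= i + lookahead, so in a stream it
    /// can only be computed once input frame i + lookahead has arrived.
    fn delay(&self) -> usize {
        self.lookahead
    }
}

/// Expected delay of a stack: every layer waits for its own lookahead frames.
fn expected_stack_delay(layers: &[TimeConv]) -> usize {
    layers.iter().map(TimeConv::delay).sum()
}
```

With both layers at lookahead=0, as in the repro, the expected delay is 0, which is what the commented-out assertion in the Rust code checks for.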
-
Another question about the reported tract delay. I am exporting a delay module here:
```python
from torch import Tensor
import torch.nn.functional as F


def df_delay_spec(spec: Tensor, delay: int) -> Tensor:
    # spec (real) [B, 1, T, F', 2]
    # F.pad pads from the last dim backwards: (last dim), (F'), (T), so T is right-padded by `delay`
    return F.pad(spec, (0, 0, 0, 0, 0, delay))
```
Tract, on the other hand, does not report any delay, regardless of the padding amount:
Edit: I would have expected it to report the same delay that we (right-)padded. Only left-padding should result in no delay. Edit2: Or maybe this is an issue with tract <-> torch.nn.functional.pad, which does not result in the correct delay calculation, as reported earlier. There, the input was left-padded using a separate padding function and should not result in any delay.
-
The delay is the number of "uninitialized" frames (units of the time axis) that the pulsed network will output and that you need to skip when reading the beginning of the output. Padding on the right can be added without delaying the data in the pulse wave: as a consequence, the op does not introduce a delay, it just makes its output longer.
0 are pads, * are invalid data. Note that with left padding, we do not introduce a pulse delay, even if the actual signal is delayed. The left pad on the time axis has a special implementation, with a memory buffer.
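A plain-Rust sketch of the two behaviours (illustrative only, not tract's implementation; the names are hypothetical). The left pad is a delay line pre-filled with k zeros: every frame it emits is valid, so the pulse delay stays 0 while the signal itself comes out k frames later. The right pad passes frames through untouched and appends k zeros when the stream ends, which only makes the output longer.

```rust
use std::collections::VecDeque;

/// Hypothetical streaming left pad of `k` frames: a delay line pre-filled with
/// `k` zeros (the "memory buffer"). Each pushed frame comes out `k` calls later.
struct StreamingLeftPad {
    line: VecDeque<f32>,
}

impl StreamingLeftPad {
    fn new(k: usize) -> Self {
        Self { line: std::iter::repeat(0.0f32).take(k).collect() }
    }

    /// One frame in, one (always valid) frame out.
    fn push(&mut self, x: f32) -> f32 {
        self.line.push_back(x);
        self.line.pop_front().expect("delay line is never empty after a push")
    }
}

/// Streaming right pad of `k` frames: frames pass through unchanged while the
/// stream runs, and `k` zeros are emitted once it ends.
fn right_pad_stream(input: &[f32], k: usize) -> Vec<f32> {
    let mut out: Vec<f32> = input.to_vec(); // pass-through during the stream
    out.extend(std::iter::repeat(0.0f32).take(k)); // appended at end of stream
    out
}
```

Pushing 1, 2, 3 through a left pad of 2 yields 0, 0, 1: no invalid frame to skip, but the signal is shifted by two frames.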
-
Hi there,
I work with sequential data and need an output from the model at every time step. When using RNN-like models, this is straightforward: I can pass an input of [B,T,F] to ONNX, where T=1, B=1 for simplicity, and F are the features.
For convolutions this is a little more complicated if the convolution kernel covers multiple time steps (let's say 3 in a 3x1xC_in x C_out kernel).
During inference, we need a buffer for each convolution that holds the convolution input of size 3x1xC_in, to be able to run the model in a loop over time steps.
From looking at the code, the Pulse module might be intended to cover this use case. Is this correct? I haven't found any example or documentation on this.
If not, the buffer handling could obviously also be done outside of tract, by providing an input for each convolution or by splitting the convolutions into different models.
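A minimal sketch of the buffer-per-convolution approach described above, done outside tract in plain Rust (hypothetical names; the [tap][c_in][c_out] weight layout is an assumption). Each layer keeps its last k input frames, starting from zeros, which corresponds to causal left padding of k - 1 frames, and produces one output frame per input frame.

```rust
/// Hypothetical per-layer state for a (k time steps x 1 feature) convolution
/// from c_in to c_out channels, run one time step at a time.
struct BufferedTimeConv {
    k: usize,
    c_in: usize,
    c_out: usize,
    weights: Vec<f32>,      // indexed [tap][c_in][c_out], tap 0 = oldest frame
    bias: Vec<f32>,         // length c_out
    history: Vec<Vec<f32>>, // last k input frames, oldest first, each of length c_in
}

impl BufferedTimeConv {
    fn new(k: usize, c_in: usize, c_out: usize, weights: Vec<f32>, bias: Vec<f32>) -> Self {
        assert_eq!(weights.len(), k * c_in * c_out);
        assert_eq!(bias.len(), c_out);
        // Zero-filled history = causal left padding of k - 1 frames.
        Self { k, c_in, c_out, weights, bias, history: vec![vec![0.0; c_in]; k] }
    }

    /// Consume one input frame (c_in values) and produce one output frame (c_out values).
    fn step(&mut self, frame: &[f32]) -> Vec<f32> {
        assert_eq!(frame.len(), self.c_in);
        self.history.remove(0);
        self.history.push(frame.to_vec());
        debug_assert_eq!(self.history.len(), self.k);
        let mut out = self.bias.clone();
        for (tap, past) in self.history.iter().enumerate() {
            for (i, &x) in past.iter().enumerate() {
                for (o, y) in out.iter_mut().enumerate() {
                    *y += x * self.weights[(tap * self.c_in + i) * self.c_out + o];
                }
            }
        }
        out
    }
}
```

Stacking layers means feeding each layer's output frame into the next layer's `step`; every layer owns its own history, which is exactly the per-convolution buffer the question describes.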