&[usize] -> TVec<usize> all the shapes
kali committed Apr 26, 2024
1 parent 9068b53 commit 7662183
Showing 76 changed files with 325 additions and 342 deletions.
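
The pattern throughout this diff: tensor-building entry points that used to borrow a shape as &[usize] now take an owned TVec<usize>, so call sites pass shape, shape.into(), shape.clone(), or a tvec!(...) literal instead of &shape. A minimal sketch of the ownership trade-off, with stand-in types (Vec for tract's TVec, which is a SmallVec, and a toy Tensor; both are assumptions so the example stands alone):

// Stand-ins, not tract's real API: TVec is really a SmallVec and the
// real constructors live on tract's Tensor.
type TVec<T> = Vec<T>;

struct Tensor {
    shape: TVec<usize>,
}

impl Tensor {
    // Old style: borrowing the shape forces an internal copy.
    fn zero_borrowed(shape: &[usize]) -> Tensor {
        Tensor { shape: shape.to_vec() }
    }

    // New style: take the TVec by value; a caller that is done with its
    // shape moves it in without any copy.
    fn zero_owned(shape: TVec<usize>) -> Tensor {
        Tensor { shape }
    }
}

fn main() {
    let shape: TVec<usize> = vec![2, 3, 4];
    let a = Tensor::zero_borrowed(&shape); // always copies the shape
    let b = Tensor::zero_owned(shape); // moves it, no copy
    assert_eq!(a.shape, b.shape);
}

The borrowed signature copies unconditionally; the owned one makes the cost explicit at each call site, as the hunks below show.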
2 changes: 1 addition & 1 deletion api/rs/src/lib.rs
@@ -364,7 +364,7 @@ impl ValueInterface for Value {
let dt = to_internal_dt(dt);
let len = shape.iter().product::<usize>() * dt.size_of();
anyhow::ensure!(len == data.len());
-let tensor = unsafe { Tensor::from_raw_dt(dt, shape, data)? };
+let tensor = unsafe { Tensor::from_raw_dt(dt, shape.into(), data)? };
Ok(Value(tensor.into_tvalue()))
}

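The conversions recurring at the updated call sites, here and below, sketched with Vec standing in for TVec (an assumption; tract's tvec! macro plays the role vec! plays here):

use std::borrow::Cow;

// Stand-in for any constructor that now consumes its shape.
fn takes_owned(shape: Vec<usize>) -> usize {
    shape.iter().product()
}

fn main() {
    let slice: &[usize] = &[2, 3];
    let cow: Cow<[usize]> = Cow::Borrowed(&[4, 5]);

    takes_owned(slice.into()); // &[usize] -> Vec: one explicit copy
    takes_owned(cow.into_owned()); // Cow: copies only if still borrowed
    takes_owned(vec![6, 7]); // literal shape, as tvec!(6, 7) does below
}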
2 changes: 1 addition & 1 deletion core/src/ops/array/broadcast.rs
@@ -24,7 +24,7 @@ impl EvalOp for MultiBroadcastTo {
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
let shape = self.shape.eval_to_usize(&session.resolved_symbols)?;
-Ok(tvec!(inputs[0].broadcast_to_shape(&shape)?.into_tvalue()))
+Ok(tvec!(inputs[0].broadcast_to_shape(shape.into_owned())?.into_tvalue()))
}
}

4 changes: 2 additions & 2 deletions core/src/ops/array/gather.rs
@@ -29,10 +29,10 @@ impl Gather {
unsafe fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<TValue> {
let data_view = data.to_array_view_unchecked::<T>();
let indices = indices.to_array_view::<i64>()?;
-let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
+let output_shape = self.compute_output_shape(data.shape(), indices.shape())?;
let mut output = Tensor::uninitialized::<T>(output_shape)?;
let mut output_view = output.to_array_view_mut::<T>()?;
-for coords in tract_ndarray::indices(output_shape) {
+for coords in tract_ndarray::indices(output_view.shape()) {
let ocoords = coords.as_array_view();
let ocoords = ocoords.as_slice().unwrap();
let mut icoords: TVec<usize> = ocoords[0..self.axis].into();
2 changes: 1 addition & 1 deletion core/src/ops/array/gather_nd.rs
@@ -91,7 +91,7 @@ impl EvalOp for GatherNd {
let indices = indices.cast_to::<i32>()?;
let indices = indices.to_array_view::<i32>()?;
unsafe {
-let mut output = Tensor::uninitialized_dt(data.datum_type(), &shape)?;
+let mut output = Tensor::uninitialized_dt(data.datum_type(), shape)?;
dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
self,
&mut output,
2 changes: 1 addition & 1 deletion core/src/ops/array/one_hot.rs
@@ -57,7 +57,7 @@ impl EvalOp for OneHot {
let mut shape: TVec<usize> = input.shape().into();
shape.insert(self.axis, self.dim);
unsafe {
-let mut output = self.off.broadcast_scalar_to_shape(&shape)?;
+let mut output = self.off.broadcast_scalar_to_shape(shape)?;
dispatch_datum_by_size!(Self::eval_t(self.off.datum_type())(
self,
&input,
2 changes: 1 addition & 1 deletion core/src/ops/array/range.rs
@@ -37,7 +37,7 @@ impl Range {
len: usize,
) -> TractResult<Tensor> {
unsafe {
-let mut result = Tensor::uninitialized::<T>(&[len])?;
+let mut result = Tensor::uninitialized::<T>(tvec!(len))?;
let mut v = start.to_scalar::<T>()?.clone();
let step = step.to_scalar::<T>()?;
for i in 0..len {
4 changes: 1 addition & 3 deletions core/src/ops/array/reshape.rs
@@ -18,8 +18,6 @@ impl Op for FiniteReshape {
op_as_typed_op!();
}

-
-
impl EvalOp for FiniteReshape {
fn is_stateless(&self) -> bool {
true
@@ -29,7 +27,7 @@ impl EvalOp for FiniteReshape {
let input = args_1!(inputs);
let mut tensor = input.into_tensor();
unsafe {
-tensor.set_shape_unchecked(&self.shape);
+tensor.set_shape_unchecked(self.shape.clone());
}
Ok(tvec!(tensor.into_tvalue()))
}
2 changes: 1 addition & 1 deletion core/src/ops/array/slice.rs
@@ -86,7 +86,7 @@ fn eval_slice(input: &Tensor, axis: usize, start: usize, end: usize) -> TractRes
unsafe {
let mut shape: TVec<_> = input.shape().into();
shape[axis] = end - start;
-let mut tensor = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
+let mut tensor = Tensor::uninitialized_dt(input.datum_type(), shape)?;
tensor.assign_slice_unchecked(.., input, start..end, axis);
Ok(tvec!(tensor.into_tvalue()))
}
6 changes: 3 additions & 3 deletions core/src/ops/array/topk.rs
@@ -31,9 +31,9 @@ impl EvalOp for Topk {
let k = k.cast_to_scalar::<i64>()? as usize;
output_shape[self.axis] = k;
let dt = input.datum_type();
-let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
-let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
-let mut iterating_shape = output_shape.clone();
+let mut output_values = Tensor::zero_dt(dt, output_shape.clone())?;
+let mut output_indices = Tensor::zero::<i64>(output_shape.clone())?;
+let mut iterating_shape = output_shape;
iterating_shape[self.axis] = 1;
let mut output_indices_view = output_indices.to_array_view_mut::<i64>()?;
for coords in tract_ndarray::indices(&*iterating_shape) {
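The Topk hunk just above shows the discipline an owned-shape API imposes: clone while the shape is still needed, move it on the last use. A self-contained sketch under the same stand-in assumption (Vec for TVec):

type TVec<T> = Vec<T>; // assumption: stand-in for tract's SmallVec-based TVec

fn consume(shape: TVec<usize>) -> usize {
    shape.iter().product()
}

fn main() {
    let output_shape: TVec<usize> = vec![4, 8];
    let values = consume(output_shape.clone()); // shape reused below: clone
    let indices = consume(output_shape.clone()); // shape reused below: clone
    let mut iterating_shape = output_shape; // last use: move, no copy
    iterating_shape[0] = 1;
    assert_eq!(values + indices, 64);
    assert_eq!(iterating_shape, vec![1, 8]);
}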
6 changes: 3 additions & 3 deletions core/src/ops/binary.rs
@@ -101,7 +101,7 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static +
self.eval_in_a(&mut a, &b)?;
Ok(a)
} else {
-let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
+let mut c = unsafe { Tensor::uninitialized_dt(c_dt, c_shape)? };
self.eval_out_of_place(&mut c, &a, &b)?;
Ok(c)
}
@@ -584,7 +584,7 @@ macro_rules! bin_to_super_type {
let b = b.to_array_view::<u8>()?;
let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])
.context("no broadcast solution")?;
-let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
+let mut c = Tensor::zero_dt(*c_dt, c_shape)?;
let view = c.to_array_view_mut::<u8>()?;
$crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
*c = (scale_by($q_op_on_f32(
@@ -613,7 +613,7 @@ macro_rules! bin_to_super_type {
let b = b.cast_to_dt(accumulator_dt)?.into_owned();
let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])
.context("no broadcast solution")?;
-let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
+let mut c = Tensor::zero_dt(accumulator_dt, c_shape)?;
match accumulator_dt {
DatumType::F32 => {
let view = c.to_array_view_mut::<f32>()?;
2 changes: 1 addition & 1 deletion core/src/ops/cast.rs
@@ -32,7 +32,7 @@ impl EvalOp for Cast {
if input.datum_type() == self.to {
Ok(tvec!(input))
} else if input.datum_type() == TDim::datum_type() {
-let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape())?;
+let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape().into())?;
for (dim, i) in
tract_itertools::izip!(input.as_slice::<TDim>()?, tmp.as_slice_mut::<i64>()?)
{
4 changes: 2 additions & 2 deletions core/src/ops/change_axes.rs
@@ -360,7 +360,7 @@ impl AxisOp {
Reshape(at, from, to) => {
let mut shape: TVec<usize> = tensor.shape().into();
self.change_shape_array(&mut shape, false)?;
-if tensor.set_shape(&shape).is_ok() {
+if tensor.set_shape(shape).is_ok() {
Ok(())
} else if broadcasting
&& tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
@@ -1116,7 +1116,7 @@ mod proptests {

fn input(&self) -> TractResult<Tensor> {
unsafe {
-let mut t = Tensor::uninitialized::<i64>(&self.input)?;
+let mut t = Tensor::uninitialized::<i64>(self.input.clone())?;
for i in 0..t.len() {
t.as_slice_mut().unwrap()[i] = i as i64;
}
2 changes: 1 addition & 1 deletion core/src/ops/cnn/conv/depth_wise.rs
@@ -99,7 +99,7 @@ macro_rules! impl_eval {
mul: impl Fn(T, T) -> T + Copy + 'static,
) -> TractResult<TVec<TValue>> {
let (img, kernel, bias) = args_3!(inputs);
-let mut output = unsafe { Tensor::uninitialized::<T>(&dw.output_shape.shape)? };
+let mut output = unsafe { Tensor::uninitialized::<T>(dw.output_shape.shape.clone())? };
let iptr = img.as_ptr::<T>()?;
let optr = output.as_ptr_mut::<T>()?;
let k_stride_i = kernel.strides()[1];
6 changes: 3 additions & 3 deletions core/src/ops/cnn/conv/im2col.rs
@@ -146,7 +146,7 @@ impl EvalOp for Im2Col {
unsafe {
let mut input = inputs.remove(0).into_tensor();
let pad_value: Option<&Tensor> = if inputs.len() > 0 { Some(&inputs[0]) } else { None };
-let mut output = Tensor::uninitialized::<Opaque>(&geometry.packed_shape)?;
+let mut output = Tensor::uninitialized::<Opaque>(geometry.packed_shape.clone())?;
if !self.pool_spec.data_format.has_n() {
input.insert_axis(0)?;
}
@@ -162,7 +162,7 @@
for g in 0..self.group {
let mut data = Tensor::uninitialized_aligned_dt(
input.datum_type(),
-    &[geometry.b_pack.len(geometry.k, geometry.n)],
+    tvec![geometry.b_pack.len(geometry.k, geometry.n)],
geometry.b_pack.alignment(),
)?;
dispatch_copy_by_size!(Patcher::patch(input.datum_type())(
@@ -264,7 +264,7 @@ impl Patcher {
) -> TractResult<()> {
unsafe {
let pad_value = *pad_value.to_scalar_unchecked();
-let mut mega_matrix = Tensor::uninitialized::<T>(&[geometry.k, geometry.n])?;
+let mut mega_matrix = Tensor::uninitialized::<T>(tvec![geometry.k, geometry.n])?;
let mut mega_matrix_view = mega_matrix.to_array_view_mut_unchecked::<T>();
let ptr = input.as_ptr_unchecked::<T>();
let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * (g * geometry.ci_per_group));
2 changes: 1 addition & 1 deletion core/src/ops/cnn/deconv/deconv_sum.rs
@@ -73,7 +73,7 @@ impl DeconvSum {
let mut tensor = bias.into_tensor();
let hw = *gemm.shape().last().unwrap();
let n = *output_shape.n().unwrap_or(&1);
-let n_o_hkwk_hw = gemm.into_tensor().into_shape(&[
+let n_o_hkwk_hw = gemm.into_tensor().into_shape(tvec![
n,
*output_shape.c(),
self.pool_spec.kernel_shape.iter().product(),
8 changes: 4 additions & 4 deletions core/src/ops/cnn/patches.rs
@@ -813,7 +813,7 @@ pub mod test {
fn reference_sumpool(&self) -> Tensor {
let input_shape = self.input_shape();
let output_shape = self.output_shape();
-let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
+let mut output = Tensor::zero::<f32>(output_shape.shape.clone()).unwrap();
for geo_out in tract_ndarray::indices(output_shape.hw_dims()) {
for geo_ker in tract_ndarray::indices(&*self.patch.spec.kernel_shape) {
let geo_in: TVec<isize> = izip!(
@@ -845,7 +845,7 @@
fn check_visitor(&self) {
let input_shape = self.input_shape();
let output_shape = self.output_shape();
-let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
+let mut output = Tensor::zero::<f32>(output_shape.shape.clone()).unwrap();
self.patch.visit_output(|visitor| {
for (_k, offset_in) in visitor.valid_offsets_ker_in() {
for c in 0..*output_shape.c() {
@@ -862,7 +862,7 @@
fn check_zone_visitor(&self) {
let input_shape = self.input_shape();
let output_shape = self.output_shape();
-let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
+let mut output = Tensor::zero::<f32>(output_shape.shape.clone()).unwrap();
for zone in &self.patch.zones {
zone.visit_output(&self.patch, |visitor| {
for (_k, offset_in) in visitor.valid_offsets_ker_in() {
@@ -945,7 +945,7 @@ pub mod test {
#[test]
fn test_visitor_1() {
let input_shape = NCHW.from_n_c_hw(1, 1, [2, 2]).unwrap();
-let input = Tensor::zero::<f32>(&input_shape.shape).unwrap();
+let input = Tensor::zero::<f32>(input_shape.shape.clone()).unwrap();
let patch = PatchSpec::for_data_shape(input_shape.clone())
.with_kernel_shape(tvec![2, 1])
.with_padding(PaddingSpec::SameLower)
10 changes: 6 additions & 4 deletions core/src/ops/cnn/sumpool.rs
@@ -90,8 +90,9 @@ impl EvalOp for LirSumPool {
let input = args_1!(inputs);
let geo = self.geometry.to_concrete(input.shape())?;
let values = if input.datum_type().is_float() {
-let mut values =
-    unsafe { Tensor::uninitialized_dt(input.datum_type(), &geo.output_shape.shape)? };
+let mut values = unsafe {
+    Tensor::uninitialized_dt(input.datum_type(), geo.output_shape.shape.clone())?
+};
dispatch_floatlike!(Self::eval_t(input.datum_type())(
self,
&*input,
@@ -100,8 +101,9 @@
))?;
values
} else {
-let mut values =
-    unsafe { Tensor::uninitialized_dt(DatumType::F32, &geo.output_shape.shape)? };
+let mut values = unsafe {
+    Tensor::uninitialized_dt(DatumType::F32, geo.output_shape.shape.clone())?
+};
let input_f32 = input.cast_to_dt(DatumType::F32)?;
self.eval_t::<f32>(input_f32.as_ref(), values.as_ptr_mut()?, geo.as_ref())?;
values.cast_to_dt(input.datum_type())?.into_owned()
7 changes: 5 additions & 2 deletions core/src/ops/downsample/mod.rs
@@ -59,7 +59,7 @@ impl EvalOp for Downsample {
let t = if self.modulo > input.shape()[self.axis] {
let mut shape: TVec<usize> = input.shape().into();
shape[self.axis] = 0;
-Tensor::uninitialized_dt(input.datum_type(), &shape)?
+Tensor::uninitialized_dt(input.datum_type(), shape)?
} else {
let slice = ndarray::Slice::new(self.modulo as isize, None, self.stride);
unsafe fn do_slice<T: Datum>(
@@ -86,7 +86,10 @@ impl EvalOp for Downsample {
impl TypedOp for Downsample {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(self.axis < inputs[0].rank());
-ensure!(self.modulo == 0 || self.stride > 0, "non-zero modulo is only defined with forward strides");
+ensure!(
+    self.modulo == 0 || self.stride > 0,
+    "non-zero modulo is only defined with forward strides"
+);
let mut downed = inputs[0].clone();
let down_len = self.transform_dim(&downed.shape[self.axis]);
downed.shape.set(self.axis, down_len);
4 changes: 2 additions & 2 deletions core/src/ops/einsum/as_blas.rs
@@ -84,11 +84,11 @@ impl EvalOp for SGemm {
let n = c_shape[rank - 1];
let k = a.shape()[rank - 1];
unsafe {
-let mut c = Tensor::uninitialized::<f32>(&c_shape)?;
+let mut c = Tensor::uninitialized::<f32>(c_shape)?;
let c_ptr = c.as_ptr_mut::<f32>()?;
let silent_a_axis = c.rank() - a.rank();
let silent_b_axis = c.rank() - b.rank();
-for prefix in ndarray::indices(&c_shape[0..rank - 2]) {
+for prefix in ndarray::indices(&c.shape()[0..rank - 2]) {
let mut a_ptr = a_ptr;
let mut b_ptr = b_ptr;
let mut c_ptr = c_ptr;
[diff truncated: 57 of the 76 changed files are not shown]
