From e60c1922ad7035f0eca69662496a196458bcfac9 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Thu, 24 Apr 2025 16:08:03 -0400 Subject: [PATCH 1/2] [Tosa] : Equalize all operands for select. --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 3 ++- test/Conversion/TorchToTosa/basic.mlir | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 1bc2bedd660d..b72aff0a26cf 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5069,7 +5069,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( dyn_cast(getTypeConverter()->convertType(op.getType())); if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, self).failed() || - mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, other).failed()) + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, other).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, other).failed()) return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 52bb09b594fc..53fddd839495 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1421,6 +1421,28 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to return %0 : !torch.vtensor<[1,12,5,5],f32> } +// ----- +// CHECK-LABEL: func.func @torch.aten.where.self_differing_rank_inputs( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,4],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,3,1,1,5,4],f32>) -> !torch.vtensor<[1,3,1,1,5,4],f32> { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1,3,1,1,5,4],f32> -> tensor<1x3x1x1x5x4xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4],i1> -> tensor<5x4xi1> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_6]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 1, 5, 4]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_8]] : (tensor<5x4xi1>, !tosa.shape<6>) -> tensor<1x1x1x1x5x4xi1> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_10]] : (tensor<1x1xf32>, !tosa.shape<6>) -> tensor<1x1x1x1x1x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_9]], %[[VAL_11]], %[[VAL_3]] : (tensor<1x1x1x1x5x4xi1>, tensor<1x1x1x1x1x1xf32>, tensor<1x3x1x1x5x4xf32>) -> tensor<1x3x1x1x5x4xf32> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x3x1x1x5x4xf32> -> !torch.vtensor<[1,3,1,1,5,4],f32> +// CHECK: return %[[VAL_13]] +func.func @torch.aten.where.self_differing_rank_inputs(%40: !torch.vtensor<[5,4],i1>, %41: !torch.vtensor<[],f32>, %38 : !torch.vtensor<[1,3,1,1,5,4],f32>) -> (!torch.vtensor<[1,3,1,1,5,4],f32>) { + %42 = torch.aten.where.self %40, %41, %38 : !torch.vtensor<[5,4],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[1,3,1,1,5,4],f32> -> !torch.vtensor<[1,3,1,1,5,4],f32> + return %42: !torch.vtensor<[1,3,1,1,5,4],f32> +} + // ----- // CHECK-LABEL: func.func @torch.aten.remainder.Scalar( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { From 40e12d68ba80447cfe8dbc0a8214e5400cb17b07 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Fri, 25 Apr 2025 13:25:20 -0400 Subject: [PATCH 2/2] [Tosa] : Slice conv inputs for dynamic batch as long as spatial dims are static. --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 18 +++-- test/Conversion/TorchToTosa/basic.mlir | 85 ++++++++++++++++++++++ 2 files changed, 96 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b72aff0a26cf..749d49899885 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2453,9 +2453,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } int64_t outputHDim, outputWDim; - if (inputTy.hasStaticShape()) { - int64_t inputHDim = inputShape[2]; - int64_t inputWDim = inputShape[3]; + int64_t inputHDim = inputShape[2]; + int64_t inputWDim = inputShape[3]; + + bool isStaticSpatialDims = + !ShapedType::isDynamic(inputHDim) && !ShapedType::isDynamic(inputWDim); + if (isStaticSpatialDims) { + int64_t weightHDim = weightShape[2]; int64_t weightWDim = weightShape[3]; @@ -2473,8 +2477,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector sizeHSlice(transposedInputShape); // TOSA uses NHWC, so we will slice dim 1 for Height value sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]); - transposedInput = rewriter.create( - op->getLoc(), RankedTensorType::get(sizeHSlice, inputElemTy), + transposedInput = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), transposedInput, tosa::getTosaConstShape(rewriter, op->getLoc(), startHSlice), tosa::getTosaConstShape(rewriter, op->getLoc(), sizeHSlice)); @@ -2498,8 +2502,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( dyn_cast(transposedInput.getType()).getShape()); // TOSA uses NHWC, so we will slice dim 2 for Width value sizeWSlice[2] = inputWDim - (remainderWDim - padding[3]); - transposedInput = rewriter.create( - op->getLoc(), RankedTensorType::get(sizeWSlice, inputElemTy), + transposedInput = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), transposedInput, tosa::getTosaConstShape(rewriter, op->getLoc(), startWSlice), tosa::getTosaConstShape(rewriter, op->getLoc(), sizeWSlice)); diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 53fddd839495..91e67490a791 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -3757,6 +3757,91 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp return %5 : !torch.vtensor<[1,32,75,75],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,224,224],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor +// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array} : (tensor) -> tensor +// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor to tensor +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor -> !torch.vtensor<[?,32,112,112],f32> +// CHECK: return %[[VAL_19]] + +func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense_resource : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32> + %none = torch.constant.none + %int2 = torch.constant.int 2 + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,224,224],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,32,112,112],f32> + return %5 : !torch.vtensor<[?,32,112,112],f32> +} + + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,225,225],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor +// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor +// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array} : (tensor) -> tensor +// CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor to tensor +// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor -> !torch.vtensor<[?,32,75,75],f32> +// CHECK: return %[[VAL_25]] +func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense_resource : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32> + %none = torch.constant.none + %int3 = torch.constant.int 3 + %1 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,225,225],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,32,75,75],f32> + return %5 : !torch.vtensor<[?,32,75,75],f32> +} + // ----- // CHECK-LABEL: func.func @torch.aten.max_pool2d$zero_pad_with_sliced_input(