From aa7e7f7e892da927baa3443137f8c7c879820e65 Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Thu, 17 Apr 2025 11:46:34 -0400 Subject: [PATCH 1/8] Average pooling divisor clamping should be applied in all conditions where the kernel window can go out of bounds. --- lib/Conversion/TorchToLinalg/Pooling.cpp | 28 +++--- .../torch_mlir_e2e_test/test_suite/pooling.py | 89 +++++++++++++++++++ test/Conversion/TorchToLinalg/pooling.mlir | 28 ++++++ 3 files changed, 134 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index ecfa8b8a5865..3c541d87f665 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -961,9 +961,9 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { // count_include_pad parameter is equal to false. static std::optional createAvgPoolValueCountIncludePadFalseCase( - bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, + bool ceilMode, bool countIncludePad, OpTy op, + typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, + Value self, Value sumPool, Value outputTensor, Type resultType, SmallVectorImpl &kernelSizeIntValues, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, @@ -1041,9 +1041,9 @@ LogicalResult ConvertAtenAvgPoolOp::matchAndRewrite( Dim + 2, utils::IteratorType::parallel); auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase( - countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor, - resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg, - iteratorTypesAvg); + ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool, + outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts, + indexingMapsAvg, iteratorTypesAvg); if (divisorOpResult) return *divisorOpResult; @@ -1057,9 +1057,9 @@ template std::optional ConvertAtenAvgPoolOp:: createAvgPoolValueCountIncludePadFalseCase( - bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, + bool ceilMode, bool countIncludePad, OpTy op, + typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, + Value self, Value sumPool, Value outputTensor, Type resultType, SmallVectorImpl &kernelSizeIntValues, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, @@ -1069,8 +1069,14 @@ std::optional ConvertAtenAvgPoolOp:: constexpr int avgPoolDims = getAvgPoolNumOfDims(); - bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); - if (countIncludePad || noPadding) { + bool hasPadding = + !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); + bool allStridesUnitary = + llvm::all_of(strideInts, [](int64_t s) { return s == 1; }); + bool canKernelWindowGoOutOfBounds = + hasPadding || (ceilMode && !allStridesUnitary); + + if (countIncludePad || !canKernelWindowGoOutOfBounds) { // These cases are not handled here.
return std::nullopt; } diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index c6cc264d6aff..2ffa7bd772f1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -2302,3 +2302,92 @@ def MaxUnpool3dModulePad0_basic(module, tu: TestUtils): output, indices = pool(input) module.forward(output, indices) + + +class AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa(torch.nn.Module): + # This test captures the torch-mlir issue reported here: + # https://github.com/llvm/torch-mlir/issues/4079 + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa()) +def AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadUnitaryStrides(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadUnitaryStrides()) +def AvgPool2dCeilNoPadUnitaryStrides_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPadNonUnitaryStrides(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilPadNonUnitaryStrides()) +def AvgPool2dCeilPadNonUnitaryStrides_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 53faa1d37d4f..81ce07bba848 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -225,3 +225,31 @@ func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512 %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32> return %3 : !torch.vtensor<[1,512,12],f32> } + +// CHECK-LABEL: func @forward_avgpool_2d_ceil +func.func @forward_avgpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[POOL_OUT:.*]] = linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%[[PADDED_IN:.*]], %[[KERNEL_IN:.*]] : tensor<1x1x6x6xf32>, 
tensor<3x3xf32>) outs(%[[OUT1:.*]] : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> // CHECK: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL_OUT]] : tensor<1x1x2x2xf32>) outs(%[[GEN_OUT:.*]] : tensor<1x1x2x2xf32>) { // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): // CHECK-COUNT-3: arith.muli // CHECK-COUNT-1: arith.sitofp // CHECK-COUNT-1: arith.divf // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x1x2x2xf32> %int3 = torch.constant.int 3 %int3_0 = torch.constant.int 3 %int0 = torch.constant.int 0 %int0_1 = torch.constant.int 0 %int2 = torch.constant.int 2 %int2_2 = torch.constant.int 2 %int1 = torch.constant.int 1 %int1_3 = torch.constant.int 1 %0 = torch.prim.ListConstruct %int3, %int3_0 : (!torch.int, !torch.int) -> !torch.list %1 = torch.prim.ListConstruct %int0, %int0_1 : (!torch.int, !torch.int) -> !torch.list %2 = torch.prim.ListConstruct %int2, %int2_2, %int1, %int1_3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %true = torch.constant.bool true %false = torch.constant.bool false %none = torch.constant.none %3 = torch.aten.avg_pool2d %arg0, %0, %2, %1, %true, %false, %none : !torch.vtensor<[1,1,4,4],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,2,2],f32> return %3 : !torch.vtensor<[1,1,2,2],f32> } From 0e965b8eada32a722c07cbf05fcc31df9d34eefc Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Wed, 23 Apr 2025 17:01:39 -0400 Subject: [PATCH 2/8] Updated divisor algorithm after finding SWA in existing logic (prior to any change of mine). --- lib/Conversion/TorchToLinalg/Pooling.cpp | 192 +++++++++++----- projects/pt1/e2e_testing/xfail_sets.py | 10 + .../torch_mlir_e2e_test/test_suite/pooling.py | 217 ++++++++++++++++++ test/Conversion/TorchToLinalg/pooling.mlir | 2 +- 4 files changed, 366 insertions(+), 55 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 3c541d87f665..aeb6de15cecc 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -856,7 +856,7 @@ namespace { // used in the divisor of the average pooling operator.
template class PoolSizeCalculator { public: - PoolSizeCalculator(Value self, Value sumPool, + PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad, ConversionPatternRewriter &rewriter, Location loc); // The algorithm for computing the divisor with @@ -871,18 +871,19 @@ template class PoolSizeCalculator { SmallVectorImpl &paddingInts); private: - int64_t DimSizeFromSumPoolType[NumOfDims]; - Value InputSpatialDimValues[NumOfDims]; + int64_t SumPoolTypeDimIndex[NumOfDims]; + Value InputSpatialDimSizes[NumOfDims]; Location location; + bool isCountIncludePad; }; } // namespace template PoolSizeCalculator::PoolSizeCalculator( - Value self, Value sumPool, ConversionPatternRewriter &rewriter, - Location loc) - : location(loc) { + Value self, Value sumPool, bool countIncludePad, + ConversionPatternRewriter &rewriter, Location loc) + : location(loc), isCountIncludePad(countIncludePad) { auto selfType = cast(self.getType()); const int64_t selfRank = selfType.getRank(); RankedTensorType sumPoolType = cast(sumPool.getType()); @@ -891,57 +892,117 @@ PoolSizeCalculator::PoolSizeCalculator( // Store dimensions in this order: // 0 => width, 1 => height, 2 => depth for (int i = 0; i < NumOfDims; ++i) { - int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank); - InputSpatialDimValues[i] = - getDimOp(rewriter, location, self, DimSizeFromSelfType); - DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank); + int64_t inputSpatialDimIndex = toPositiveDim(-(i + 1), selfRank); + InputSpatialDimSizes[i] = + getDimOp(rewriter, location, self, inputSpatialDimIndex); + SumPoolTypeDimIndex[i] = toPositiveDim(-(i + 1), rank); } } template Value PoolSizeCalculator::getPoolSize( - OpBuilder &b, SmallVectorImpl &kernelSizeIntValues, + OpBuilder &b, SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts) { Value poolSize; Value cstZero = b.createOrFold(location, b.getI64IntegerAttr(0)); + Value cstOne = + b.createOrFold(location, b.getI64IntegerAttr(1)); + Value cstTwo = + b.createOrFold(location, b.getI64IntegerAttr(2)); for (int i = 0; i < NumOfDims; ++i) { - // See the link below for the PyTorch implementation where this is - // derived from: - // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 - // Dim below stands for spatial dimension. Prior to the February 2025 - // change, these variables used "height" and "width" (or "h" and "w") - // in these intermediate variables instead of "Dim". - Value IndexODim = + // The following code computes the clamped kernel size used to compute + // the divisor of the average pooling operator. Here is the formula that + // it represents: + // + // indexStartOffset = ceil((kernelSize - 1)/2) - padding + // + // clampedKernelSize = + // min(outIntIndex * stride + indexStartOffset + floor((kernelSize - 1)/2) + // + 1, + // InputSpatialDimSize + padding) - + // max(outIntIndex * stride + indexStartOffset - ceil((kernelSize - 1)/2), + // -padding) + // + // The outIntIndex is the current iteration value coming from the + // linalg.generic op and it represents the center of the kernel window. + // The padding above becomes zero if count_include_pad is false. + // The kernelSize - 1 is used to subtract the center element of the kernel + // from the kernel size before dividing by two. Note that PyTorch even + // kernel dimensions are biased to the lower side of the dimension. Hence + // the lower length uses ceiling. While the upper length uses floor. 
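+ //
+ // Hand-worked example of the formula above (added for clarity, and
+ // assuming count_include_pad is true so the padding term is kept):
+ // with kernelSize = 3, stride = 2, padding = 1, and
+ // InputSpatialDimSize = 4, indexStartOffset = 1 - 1 = 0, and for
+ // outIntIndex = 2:
+ //   clampedKernelSize = min(2 * 2 + 0 + 1 + 1, 4 + 1) -
+ //                       max(2 * 2 + 0 - 1, -1) = 5 - 3 = 2,
+ // i.e. the last window counts one input element (index 3) and one
+ // explicit padding element (index 4), but not the ceil-mode
+ // overshoot at index 5.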
+ // + // If count_include_pad is true, in most cases the divisor is just the + // product of kernel dimensions. But we still need this logic for the + // case in which the ceiling mode is true since the kernel window + // center can go into the padding outside of the input tensor. This + // introduces an implicit padding that is not controlled by the + // count_include_pad parameter. See the + // AvgPool2dCeilPaddingStridedIncludePadding E2E test for details. + + Value padding = b.createOrFold( + location, b.getI64IntegerAttr(paddingInts[i])); + Value InputSpatialDimSize = + castIndexToInt64(b, location, InputSpatialDimSizes[i]); + // Subtract center element from kernel size before division by two. + Value kernelSizeMinusOne = + b.createOrFold(location, kernelDimSizes[i], cstOne); + // PyTorch even kernel dimensions are biased to the lower side of the + // dimension. Hence the lower length uses ceiling. + Value kernelLowerLength = b.createOrFold( + location, kernelSizeMinusOne, cstTwo); + // While the upper length uses floor. + Value kernelUpperLength = b.createOrFold( + location, kernelSizeMinusOne, cstTwo); + + // The larger the padding, the closer to (and beyond) the lower bound + // of the input tensor the window can start reading. + Value indexStartOffset = + b.createOrFold(location, kernelLowerLength, padding); + + Value outIndex = b.createOrFold(location, - /*value=*/DimSizeFromSumPoolType[i]); - Value ODim = castIndexToInt64(b, location, IndexODim); - Value DDim = b.createOrFold( + /*value=*/SumPoolTypeDimIndex[i]); + Value outIntIndex = castIndexToInt64(b, location, outIndex); + + Value stride = b.createOrFold( location, b.getI64IntegerAttr(strideInts[i])); - Value PadDim = b.createOrFold( - location, b.getI64IntegerAttr(paddingInts[i])); - Value ODimDDim = b.createOrFold(location, ODim, DDim); - Value IDim0 = b.createOrFold(location, ODimDDim, PadDim); - Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]); - Value IDim0KDim = - b.createOrFold(location, IDim0, kernelSizeIntValues[i]); - Value IDimPadDim = b.createOrFold(location, IDim, PadDim); - Value IDim1 = - b.createOrFold(location, IDim0KDim, IDimPadDim); - - Value IDim0Clamped = - b.createOrFold(location, IDim0, cstZero); - Value IDim1Clamped = b.createOrFold(location, IDim1, IDim); - Value IDim1_IDim0_Clamped = - b.createOrFold(location, IDim1Clamped, IDim0Clamped); + + Value indexStrided = b.createOrFold( + location, b.createOrFold(location, outIntIndex, stride), + indexStartOffset); + + Value inputUpperBound = isCountIncludePad + ? b.createOrFold( + location, InputSpatialDimSize, padding) + : InputSpatialDimSize; + Value inputLowerBound = + isCountIncludePad + ?
b.createOrFold(location, cstZero, padding) + : cstZero; + + Value upperBoundMinusOne = b.createOrFold( + location, indexStrided, kernelUpperLength); + Value upperBound = + b.createOrFold(location, upperBoundMinusOne, cstOne); + Value upperBoundClamped = + b.createOrFold(location, upperBound, inputUpperBound); + + Value lowerBound = b.createOrFold(location, indexStrided, + kernelLowerLength); + Value lowerBoundClamped = + b.createOrFold(location, lowerBound, inputLowerBound); + Value clampedKernelSize = b.createOrFold( + location, upperBoundClamped, lowerBoundClamped); + if (i == 0) { - poolSize = IDim1_IDim0_Clamped; + poolSize = clampedKernelSize; } else { - poolSize = b.createOrFold(location, poolSize, - IDim1_IDim0_Clamped); + poolSize = + b.createOrFold(location, poolSize, clampedKernelSize); } } return poolSize; @@ -964,7 +1025,7 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { bool ceilMode, bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVector &indexingMapsAvg, @@ -976,7 +1037,7 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { OpTy op, typename OpTy::Adaptor &adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg); }; @@ -1060,7 +1121,7 @@ std::optional ConvertAtenAvgPoolOp:: bool ceilMode, bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVector &indexingMapsAvg, @@ -1073,10 +1134,33 @@ std::optional ConvertAtenAvgPoolOp:: !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); bool allStridesUnitary = llvm::all_of(strideInts, [](int64_t s) { return s == 1; }); - bool canKernelWindowGoOutOfBounds = - hasPadding || (ceilMode && !allStridesUnitary); - if (countIncludePad || !canKernelWindowGoOutOfBounds) { + // If the condition below is true, the divisor total must subtract the + // elements not counted (clamped divisor count). If false, the divisor + // is just the product of kernel dimensions. + bool divisorIsClamped = + (!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary); + // There are two ways to get the divisor clamped: through padding or + // ceiling mode. For the case when there is padding, the padding elements + // are omitted if count_include_pad == False (divisor is clamped). If + // there is no padding (padding == 0) then the count_include_pad value + // does not take effect. + // The divisor count can be clamped also through the ceil_mode. In this + // case, according to the Hout and Wout formula in this page: + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d, + // the ceil_mode will round up on the stride division. The round-up + // yields an extra output element whose window goes out of bounds, and + // PyTorch fills it with implicit zero padding. PyTorch also does not + // count these implicit zero-padding elements in the divisor, and this + // behavior is not controlled by the + // count_include_pad argument.
+ But also note that if all strides are 1 there are no fractions to + // round up, hence there is no ceiling rounding and the window will + // not go out of bounds. For this case the divisor is just the + // product of kernel dimensions. + // Search for torch.nn.AvgPool2d E2E tests for coverage of these + // conditions. + + if (!divisorIsClamped) { // These cases are not handled here. return std::nullopt; } @@ -1088,8 +1172,8 @@ std::optional ConvertAtenAvgPoolOp:: Type resultElementType = cast(resultType).getElementType(); - PoolSizeCalculator poolSizeCalculator(self, sumPool, rewriter, - loc); + PoolSizeCalculator poolSizeCalculator( + self, sumPool, countIncludePad, rewriter, loc); // AtenAvgPool2/3dOp has an optional divisor_override // attribute while AtenAvgPool1dOp does not. @@ -1110,7 +1194,7 @@ std::optional ConvertAtenAvgPoolOp:: [&](OpBuilder &b, Location loc, ValueRange args) { if (!poolSize) { poolSize = poolSizeCalculator.getPoolSize( - b, kernelSizeIntValues, strideInts, paddingInts); + b, kernelDimSizes, strideInts, paddingInts); } Value divisor = convertScalarToDtype(b, loc, poolSize, resultElementType); @@ -1132,17 +1216,17 @@ LogicalResult ConvertAtenAvgPoolOp:: OpTy op, typename OpTy::Adaptor &adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg) { Location loc = op->getLoc(); Type resultElementType = cast(resultType).getElementType(); - Value divisor = kernelSizeIntValues[0]; - for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) { - divisor = rewriter.createOrFold(loc, divisor, - kernelSizeIntValues[i]); + Value divisor = kernelDimSizes[0]; + for (uint32_t i = 1; i < kernelDimSizes.size(); ++i) { + divisor = + rewriter.createOrFold(loc, divisor, kernelDimSizes[i]); } // Only average pooling 2D/3D have optional divisor override.
if constexpr (!std::is_same()) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 202378d1f9ac..4788eb49e48b 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -650,6 +650,10 @@ "Aten_EmbeddingBagExample_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", + "AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic", + "AvgPool2dCeilPadNonUnitaryStrides_basic", + "AvgPool2dCeilNoPadStridedIncludePadding_basic", + "AvgPool2dCeilPaddingStridedIncludePadding_basic", "AvgPool2dDivisorOverrideModule_basic", "BernoulliTensorModule_basic", "BincountMinlengthModule_basic", @@ -3527,6 +3531,9 @@ "AvgPool1dIntModule_basic", "AvgPool1dStaticModule_basic", "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dCeilPaddingStridedIncludePadding_basic", + "AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic", + "AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic", "AvgPool2dDivisorOverrideModule_basic", "AvgPool2dFloatModule_basic", "AvgPool2dIntModule_basic", @@ -3932,6 +3939,9 @@ "AtenKthvalueFloat64Module_basic", "AtenKthvalueKeepDimModule_basic", "AtenKthvalueModule_basic", + "AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic", + "AvgPool2dCeilNoPadUnitaryStrides_basic", + "AvgPool2dCeilPadNonUnitaryStrides_basic", "AvgPool2dCountIncludePadFalseStaticModule_basic", "AvgPool3dStaticModule_basic", "Conv_Transpose1dModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 2ffa7bd772f1..246863adb6a6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -2391,3 +2391,220 @@ def forward(self, x): @register_test_case(module_factory=lambda: AvgPool2dCeilPadNonUnitaryStrides()) def AvgPool2dCeilPadNonUnitaryStrides_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadStridedIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[0, 0], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadStridedIncludePadding()) +def AvgPool2dCeilNoPadStridedIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilNoPadUnitaryStrideIncludePadding() +) +def AvgPool2dCeilNoPadUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=True, + count_include_pad=False, 
+ divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse() +) +def AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dFloorNoPadUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dFloorNoPadUnitaryStrideIncludePadding() +) +def AvgPool2dFloorNoPadUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dFloorPaddingUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dFloorPaddingUnitaryStrideIncludePadding() +) +def AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPaddingUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilPaddingUnitaryStrideIncludePadding() +) +def AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPaddingStridedIncludePadding(torch.nn.Module): + # Note that in this case the kernel window center will go into the padding. + # When this happens the padding elements are counted in the divisor, but + # the out of bound elements from the ceiling are not counted + # (i.e., clamped from the divisor count). 
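+ # Hand-worked numbers for this configuration (added for clarity):
+ # with a 4x4 input, kernel 3, stride 2, padding 1, and ceil_mode=True
+ # the output is 3x3; the last window along each dimension starts at
+ # index 3 and covers index 3 (input), index 4 (explicit padding,
+ # counted), and index 5 (implicit ceil-mode padding, not counted),
+ # so it contributes 2 rather than 3 to that dimension's divisor.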
+ + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilPaddingStridedIncludePadding()) +def AvgPool2dCeilPaddingStridedIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 81ce07bba848..f94c6e864f1f 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -227,7 +227,7 @@ func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512 } // CHECK-LABEL: func @forward_avgpool_2d_ceil -func.func @forward_avgpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @forward_avgpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> { // CHECK: %[[POOL_OUT:.*]] = linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%[[PADDED_IN:.*]], %[[KERNEL_IN:.*]] : tensor<1x1x6x6xf32>, tensor<3x3xf32>) outs(%[[OUT1:.*]] : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> // CHECK: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL_OUT]] : tensor<1x1x2x2xf32>) outs(%[[GEN_OUT:.*]] : tensor<1x1x2x2xf32>) { // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): From 4cc39d9a65ad629bbbea0097c4d771144b533388 Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Thu, 24 Apr 2025 07:47:58 -0400 Subject: [PATCH 3/8] Update patterns in MLIR unit tests. 
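The expected op counts in these tests change mechanically with the new divisor lowering. As an illustrative cross-check (a sketch, not part of the patch; clamped_size is a hypothetical helper mirroring clampedKernelSize from Pooling.cpp), the formula can be compared against the divisors PyTorch itself applies:

    import math
    import torch
    import torch.nn.functional as F

    def clamped_size(out_idx, k, s, p, size, include_pad):
        # Mirrors the clampedKernelSize computation in Pooling.cpp.
        lo = math.ceil((k - 1) / 2)   # kernelLowerLength
        hi = (k - 1) // 2             # kernelUpperLength
        start = out_idx * s + lo - p  # indexStrided
        upper = min(start + hi + 1, size + p if include_pad else size)
        lower = max(start - lo, -p if include_pad else 0)
        return upper - lower

    ones = torch.ones(1, 1, 4, 4)
    # With divisor_override=1 the result is the per-window count of
    # in-bounds elements; dividing by the regular average recovers the
    # divisor PyTorch actually used.
    counts = F.avg_pool2d(ones, 3, stride=2, padding=1, ceil_mode=True,
                          divisor_override=1)
    avg = F.avg_pool2d(ones, 3, stride=2, padding=1, ceil_mode=True,
                       count_include_pad=True)
    divisors = counts / avg
    for oh in range(3):
        for ow in range(3):
            expected = (clamped_size(oh, 3, 2, 1, 4, True) *
                        clamped_size(ow, 3, 2, 1, 4, True))
            assert int(round(divisors[0, 0, oh, ow].item())) == expected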
--- test/Conversion/TorchToLinalg/pooling.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index f94c6e864f1f..ea0747cb4753 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -124,7 +124,7 @@ func.func @forward_avg_pool2d_countincludepad_false(%arg0: !torch.vtensor<[1,3,6 // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32> // CHECK: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>) // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32): - // CHECK-COUNT-4: arith.minsi + // CHECK-COUNT-1: arith.minsi // CHECK-COUNT-1: arith.divf // CHECK: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x3x61x27xf32> @@ -171,7 +171,7 @@ func.func @forward_avg_pool3dd_countincludepad_false(%arg0: !torch.vtensor<[1,3, // CHECK: linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32> // CHECK: linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>) // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): - // CHECK-COUNT-6: arith.minsi + // CHECK-COUNT-3: arith.minsi // CHECK-COUNT-1: arith.divf // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32> @@ -213,7 +213,7 @@ func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512 // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32> // CHECK: linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32> // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): - // CHECK-COUNT-2: arith.minsi + // CHECK-COUNT-1: arith.minsi // CHECK-COUNT-1: arith.divf // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x512x12xf32> From 8bd5e707df48ebad4f3eb9ed87161321d3f29705 Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Thu, 24 Apr 2025 13:10:48 -0400 Subject: [PATCH 4/8] Adding more tests and fixing issue uncovered by one of them; i.e., kernel/stride/padding elements have to be processed in reversed order relative to the spatial dimensions.
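To illustrate the ordering issue this patch fixes (a sketch, not part of the patch): PyTorch lists kernel_size, stride, and padding with the outermost spatial dimension first (e.g., [height, width] for AvgPool2d), while PoolSizeCalculator iterates spatial dimensions starting from the innermost one, so each pooling property has to be read at the reversed index NumOfDims - i - 1:

    import torch

    # Asymmetric configuration taken from the new
    # AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded test.
    ap = torch.nn.AvgPool2d(kernel_size=[3, 2], stride=[2, 3],
                            padding=[0, 0], ceil_mode=True,
                            count_include_pad=False)
    x = torch.rand(1, 1, 3, 4)  # height=3 (dim -2), width=4 (dim -1)
    # kernel_size[0]=3 pairs with height and kernel_size[1]=2 with
    # width; a loop that visits width first (i = 0) must therefore use
    # property index NumOfDims - 1 - i.
    print(ap(x).shape)  # torch.Size([1, 1, 1, 2])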
--- lib/Conversion/TorchToLinalg/Pooling.cpp | 25 ++- projects/pt1/e2e_testing/xfail_sets.py | 5 + .../torch_mlir_e2e_test/test_suite/pooling.py | 185 ++++++++++++++++++ 3 files changed, 207 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index aeb6de15cecc..80ac7bcb0199 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -857,7 +857,8 @@ namespace { template class PoolSizeCalculator { public: PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad, - ConversionPatternRewriter &rewriter, Location loc); + bool ceilMode, ConversionPatternRewriter &rewriter, + Location loc); // The algorithm for computing the divisor with // count_include_pad equal to false is mainly based on pytorch @@ -875,15 +876,16 @@ template class PoolSizeCalculator { Value InputSpatialDimSizes[NumOfDims]; Location location; bool isCountIncludePad; + bool isCeilMode; }; } // namespace template PoolSizeCalculator::PoolSizeCalculator( - Value self, Value sumPool, bool countIncludePad, + Value self, Value sumPool, bool countIncludePad, bool ceilMode, ConversionPatternRewriter &rewriter, Location loc) - : location(loc), isCountIncludePad(countIncludePad) { + : location(loc), isCountIncludePad(countIncludePad), isCeilMode(ceilMode) { auto selfType = cast(self.getType()); const int64_t selfRank = selfType.getRank(); RankedTensorType sumPoolType = cast(sumPool.getType()); @@ -943,13 +945,19 @@ Value PoolSizeCalculator::getPoolSize( // count_include_pad parameter. See the // AvgPool2dCeilPaddingStridedIncludePadding E2E test for details. + // The average pool properties of kernel size, strides, and padding are + // stored in the reverse order of the input tensor dimensions. The + // following code computes the index of the average pool property that + // corresponds to the current spatial dimension. + int avgPoolPropIdx = NumOfDims - i - 1; + Value padding = b.createOrFold( - location, b.getI64IntegerAttr(paddingInts[i])); + location, b.getI64IntegerAttr(paddingInts[avgPoolPropIdx])); Value InputSpatialDimSize = castIndexToInt64(b, location, InputSpatialDimSizes[i]); // Subtract center element from kernel size before division by two. - Value kernelSizeMinusOne = - b.createOrFold(location, kernelDimSizes[i], cstOne); + Value kernelSizeMinusOne = b.createOrFold( + location, kernelDimSizes[avgPoolPropIdx], cstOne); // PyTorch even kernel dimensions are biased to the lower side of the // dimension. Hence the lower length uses ceiling. Value kernelLowerLength = b.createOrFold( location, kernelSizeMinusOne, cstTwo); @@ -969,7 +977,7 @@ Value PoolSizeCalculator::getPoolSize( Value outIntIndex = castIndexToInt64(b, location, outIndex); Value stride = b.createOrFold( - location, b.getI64IntegerAttr(strideInts[i])); + location, b.getI64IntegerAttr(strideInts[avgPoolPropIdx])); Value indexStrided = b.createOrFold( location, b.createOrFold(location, outIntIndex, stride), indexStartOffset); @@ -979,6 +987,7 @@ Value PoolSizeCalculator::getPoolSize( ? b.createOrFold( location, InputSpatialDimSize, padding) : InputSpatialDimSize; + Value inputLowerBound = isCountIncludePad ?
b.createOrFold(location, cstZero, padding) + : cstZero; @@ -1173,7 +1182,7 @@ std::optional ConvertAtenAvgPoolOp:: Type resultElementType = cast(resultType).getElementType(); PoolSizeCalculator poolSizeCalculator( - self, sumPool, countIncludePad, rewriter, loc); + self, sumPool, countIncludePad, ceilMode, rewriter, loc); // AtenAvgPool2/3dOp has an optional divisor_override // attribute while AtenAvgPool1dOp does not. diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 4788eb49e48b..a25502899fa7 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -654,6 +654,9 @@ "AvgPool2dCeilPadNonUnitaryStrides_basic", "AvgPool2dCeilNoPadStridedIncludePadding_basic", "AvgPool2dCeilPaddingStridedIncludePadding_basic", + "AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic", "AvgPool2dDivisorOverrideModule_basic", "BernoulliTensorModule_basic", "BincountMinlengthModule_basic", @@ -3534,6 +3537,8 @@ "AvgPool2dCeilPaddingStridedIncludePadding_basic", "AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic", "AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic", + "AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic", "AvgPool2dDivisorOverrideModule_basic", "AvgPool2dFloatModule_basic", "AvgPool2dIntModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 246863adb6a6..d66d87f06a8c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -2608,3 +2608,188 @@ def forward(self, x): @register_test_case(module_factory=lambda: AvgPool2dCeilPaddingStridedIncludePadding()) def AvgPool2dCeilPaddingStridedIncludePadding_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded(torch.nn.Module): + # Different sizes used for each kernel and stride dimension. No padding. + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 2], + stride=[2, 3], + padding=[0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3, 4], torch.float32, True), + ] ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded() ) +def AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 4, low=-1)) + + +class AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded(torch.nn.Module): + # Different sizes used for each kernel, stride, and padding dimension.
+ + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 4], + stride=[2, 3], + padding=[1, 2], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded() +) +def AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 4, low=-1)) + + +class AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded(torch.nn.Module): + # 3D version of AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool3d( + kernel_size=[3, 2, 4], + stride=[3, 2, 5], + padding=[0, 0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 5, 7], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded() +) +def AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 5, 7, low=-1)) + + +class AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded(torch.nn.Module): + # 3-D version of AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool3d( + kernel_size=[3, 4, 7], + stride=[2, 3, 4], + padding=[1, 2, 3], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3, 4, 7], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded() +) +def AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 4, 7, low=-1)) + + +class AvgPool1dNoPadCeilPadNotIncluded(torch.nn.Module): + # 1D version of AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool1d( + kernel_size=[2], + stride=[2], + padding=[1], + ceil_mode=True, + count_include_pad=False, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 5], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool1dNoPadCeilPadNotIncluded()) +def AvgPool1dNoPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 5, low=-1)) + + +class AvgPool1dPadCeilPadNotIncluded(torch.nn.Module): + # 1-D version of AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool1d( + kernel_size=[2], + stride=[2], + padding=[1], + ceil_mode=True, + count_include_pad=False, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool1dPadCeilPadNotIncluded()) +def AvgPool1dPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, low=-1)) From b92d0493c1b51762089f80a94e280f730df9fd15 Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Thu, 24 Apr 2025 14:03:33 -0400 Subject: [PATCH 5/8] Filtering new tests on ONNX test suite. 
--- projects/pt1/e2e_testing/xfail_sets.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a25502899fa7..9f465c42adef 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2795,6 +2795,10 @@ "AvgPool2dSingleIntTupleParamsIncludePadModule_basic", "AvgPool2dSingleIntTupleParamsModule_basic", "AvgPool2dWithoutPadModule_basic", + "AvgPool1dNoPadCeilPadNotIncluded_basic", + "AvgPool1dPadCeilPadNotIncluded_basic", + "AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic", "BatchMlpLayerModule_basic", "BincountMinlengthModule_basic", "BincountModule_basic", From 247263ca9ada4ab108f17bc9745afccda40f2027 Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Tue, 29 Apr 2025 17:58:51 -0400 Subject: [PATCH 6/8] Merging with Vivek's change. --- lib/Conversion/TorchToLinalg/Pooling.cpp | 276 +++++++--- projects/pt1/e2e_testing/xfail_sets.py | 21 + .../torch_mlir_e2e_test/test_suite/pooling.py | 491 ++++++++++++++++++ test/Conversion/TorchToLinalg/pooling.mlir | 34 +- 4 files changed, 734 insertions(+), 88 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 45268452a992..e1f3c66a891c 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -856,8 +856,9 @@ namespace { // used in the divisor of the average pooling operator. template class PoolSizeCalculator { public: - PoolSizeCalculator(Value self, Value sumPool, - ConversionPatternRewriter &rewriter, Location loc); + PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad, + bool ceilMode, ConversionPatternRewriter &rewriter, + Location loc); // The algorithm for computing the divisor with // count_include_pad equal is mainly based on pytorch @@ -871,18 +872,20 @@ template class PoolSizeCalculator { SmallVectorImpl &paddingInts); private: - int64_t DimSizeFromSumPoolType[NumOfDims]; - Value InputSpatialDimValues[NumOfDims]; + int64_t SumPoolTypeDimIndex[NumOfDims]; + Value InputSpatialDimSizes[NumOfDims]; Location location; + bool isCountIncludePad; + bool isCeilMode; }; } // namespace template PoolSizeCalculator::PoolSizeCalculator( - Value self, Value sumPool, ConversionPatternRewriter &rewriter, - Location loc) - : location(loc) { + Value self, Value sumPool, bool countIncludePad, bool ceilMode, + ConversionPatternRewriter &rewriter, Location loc) + : location(loc), isCountIncludePad(countIncludePad), isCeilMode(ceilMode) { auto selfType = cast(self.getType()); const int64_t selfRank = selfType.getRank(); RankedTensorType sumPoolType = cast(sumPool.getType()); @@ -891,57 +894,124 @@ PoolSizeCalculator::PoolSizeCalculator( // Store dimensions in this order: // 0 => width, 1 => height, 2 => depth for (int i = 0; i < NumOfDims; ++i) { - int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank); - InputSpatialDimValues[i] = - getDimOp(rewriter, location, self, DimSizeFromSelfType); - DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank); + int64_t inputSpatialDimIndex = toPositiveDim(-(i + 1), selfRank); + InputSpatialDimSizes[i] = + getDimOp(rewriter, location, self, inputSpatialDimIndex); + SumPoolTypeDimIndex[i] = toPositiveDim(-(i + 1), rank); } } template Value PoolSizeCalculator::getPoolSize( - OpBuilder &b, SmallVectorImpl &kernelSizeIntValues, + OpBuilder &b, SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, 
SmallVectorImpl &paddingInts) { Value poolSize; Value cstZero = b.createOrFold(location, b.getI64IntegerAttr(0)); + Value cstOne = + b.createOrFold(location, b.getI64IntegerAttr(1)); + Value cstTwo = + b.createOrFold(location, b.getI64IntegerAttr(2)); for (int i = 0; i < NumOfDims; ++i) { - // See the link below for the PyTorch implementation where this is - // derived from: - // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 - // Dim below stands for spatial dimension. Prior to the February 2025 - // change, these variables used "height" and "width" (or "h" and "w") - // in these intermediate variables instead of "Dim". - Value IndexODim = + // The following code computes the clamped kernel size used to compute + // the divisor of the average pooling operator. Here is the formula that + // it represents: + // + // indexStartOffset = ceil((kernelSize - 1)/2) - padding + // + // clampedKernelSize = + // min(outIntIndex * stride + indexStartOffset + floor((kernelSize - 1)/2) + // + 1, + // InputSpatialDimSize + padding) - + // max(outIntIndex * stride + indexStartOffset - ceil((kernelSize - 1)/2), + // -padding) + // + // The outIntIndex is the current iteration value coming from the + // linalg.generic op and it represents the center of the kernel window. + // The padding above becomes zero if count_include_pad is false. + // The kernelSize - 1 is used to subtract the center element of the kernel + // from the kernel size before dividing by two. Note that PyTorch even + // kernel dimensions are biased to the lower side of the dimension. Hence + // the lower length uses ceiling. While the upper length uses floor. + // + // If count_include_pad is true, in most cases the divisor is just the + // product of kernel dimensions. But we still need this logic for the + // case in which the ceiling mode is true since the kernel window + // center can go into the padding outside of the input tensor. This + // introduces an implicit padding that is not controlled by the + // count_include_pad parameter. See the + // AvgPool2dCeilPaddingStridedIncludePadding E2E test for details. + + // The average pool properties of kernel size, strides, and padding are + // stored in the reverse order of the input tensor dimensions. The + // following code computes the index of the average pool property that + // corresponds to the current spatial dimension. + int avgPoolPropIdx = NumOfDims - i - 1; + + Value padding = b.createOrFold( + location, b.getI64IntegerAttr(paddingInts[avgPoolPropIdx])); + Value InputSpatialDimSize = + castIndexToInt64(b, location, InputSpatialDimSizes[i]); + // Subtract center element from kernel size before division by two. + Value kernelSizeMinusOne = b.createOrFold( + location, kernelDimSizes[avgPoolPropIdx], cstOne); + // PyTorch even kernel dimensions are biased to the lower side of the + // dimension. Hence the lower length uses ceiling. + Value kernelLowerLength = b.createOrFold( + location, kernelSizeMinusOne, cstTwo); + // While the upper length uses floor. + Value kernelUpperLength = b.createOrFold( + location, kernelSizeMinusOne, cstTwo); + + // The larger the padding, the closer to (and beyond) the lower bound + // of the input tensor the window can start reading.
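+ // Hand-worked example (added for clarity): with kernelSize = 3 and
+ // padding = 1, indexStartOffset = ceil(2 / 2) - 1 = 0, so the first
+ // window (outIntIndex = 0) has lowerBound = 0 - 1 = -1, i.e. it
+ // starts reading one element below the input's lower bound, inside
+ // the padding.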
+ Value indexStartOffset = + b.createOrFold(location, kernelLowerLength, padding); + + Value outIndex = b.create(location, - /*value=*/DimSizeFromSumPoolType[i]); - Value ODim = castIndexToInt64(b, location, IndexODim); - Value DDim = b.createOrFold( - location, b.getI64IntegerAttr(strideInts[i])); - Value PadDim = b.createOrFold( - location, b.getI64IntegerAttr(paddingInts[i])); - Value ODimDDim = b.createOrFold(location, ODim, DDim); - Value IDim0 = b.createOrFold(location, ODimDDim, PadDim); - Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]); - Value IDim0KDim = - b.createOrFold(location, IDim0, kernelSizeIntValues[i]); - Value IDimPadDim = b.createOrFold(location, IDim, PadDim); - Value IDim1 = - b.createOrFold(location, IDim0KDim, IDimPadDim); - - Value IDim0Clamped = - b.createOrFold(location, IDim0, cstZero); - Value IDim1Clamped = b.createOrFold(location, IDim1, IDim); - Value IDim1_IDim0_Clamped = - b.createOrFold(location, IDim1Clamped, IDim0Clamped); + /*value=*/SumPoolTypeDimIndex[i]); + Value outIntIndex = castIndexToInt64(b, location, outIndex); + + Value stride = b.createOrFold( + location, b.getI64IntegerAttr(strideInts[avgPoolPropIdx])); + + Value indexStrided = b.createOrFold( + location, b.createOrFold(location, outIntIndex, stride), + indexStartOffset); + + Value inputUpperBound = isCountIncludePad + ? b.createOrFold( + location, InputSpatialDimSize, padding) + : InputSpatialDimSize; + + Value inputLowerBound = + isCountIncludePad + ? b.createOrFold(location, cstZero, padding) + : cstZero; + + Value upperBoundMinusOne = b.createOrFold( + location, indexStrided, kernelUpperLength); + Value upperBound = + b.createOrFold(location, upperBoundMinusOne, cstOne); + Value upperBoundClamped = + b.createOrFold(location, upperBound, inputUpperBound); + + Value lowerBound = b.createOrFold(location, indexStrided, + kernelLowerLength); + Value lowerBoundClamped = + b.createOrFold(location, lowerBound, inputLowerBound); + Value clampedKernelSize = b.createOrFold( + location, upperBoundClamped, lowerBoundClamped); + if (i == 0) { - poolSize = IDim1_IDim0_Clamped; + poolSize = clampedKernelSize; } else { - poolSize = b.createOrFold(location, poolSize, - IDim1_IDim0_Clamped); + poolSize = + b.createOrFold(location, poolSize, clampedKernelSize); } } return poolSize; @@ -957,26 +1027,35 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; - // Creates the average pooling operation value when the - // count_include_pad parameter is equal to false. - static std::optional - createAvgPoolValueCountIncludePadFalseCase( - bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + // If the condition below is true, the divisor total must subtract the + // elements not counted (clamped divisor count). If false, the divisor + // is just the product of kernel dimensions. + static bool + doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts); + + // Creates the average pooling operation value with a clamped + // divisor. The clamped divisor is the product of kernel + // dimensions minus the elements not counted; e.g., padding + // and ceiling mode implicit padding. 
+ static LogicalResult createAveragePoolValueWithClampedDivisor( + bool ceilMode, bool countIncludePad, OpTy op, + typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, + Value self, Value sumPool, Value outputTensor, Type resultType, + SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg); - // Creates the average pooling operation value when the - // count_include_pad parameter is equal to true. - static LogicalResult createAvgPoolValueCountIncludePadTrueCase( + // Creates the average pooling operation value with a + // regular divisor; i.e., the product of kernel dimensions. + static LogicalResult createAveragePoolValueWithRegularDivisor( OpTy op, typename OpTy::Adaptor &adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg); }; @@ -1040,27 +1119,59 @@ LogicalResult ConvertAtenAvgPoolOp::matchAndRewrite( SmallVector iteratorTypesAvg( Dim + 2, utils::IteratorType::parallel); - auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase( - countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor, - resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg, - iteratorTypesAvg); - if (divisorOpResult) - return *divisorOpResult; - - return createAvgPoolValueCountIncludePadTrueCase( - op, adaptor, rewriter, self, sumPool, outputTensor, resultType, - kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg); + if (doesAvgPoolDivisorNeedsClamping(ceilMode, countIncludePad, strideInts, + paddingInts)) { + return createAveragePoolValueWithClampedDivisor( + ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool, + outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts, + indexingMapsAvg, iteratorTypesAvg); + } else { + return createAveragePoolValueWithRegularDivisor( + op, adaptor, rewriter, self, sumPool, outputTensor, resultType, + kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg); + } +} - return success(); +template +bool ConvertAtenAvgPoolOp:: + doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts) { + // There are two ways to get the divisor clamped: through padding or + // ceiling mode. For the case when there is padding, the padding elements + // are omitted if count_include_pad == False (divisor is clamped). If + // there is no padding (padding == 0) then the count_include_pad value + // does not take effect. + // The divisor count can be clamped also through the ceil_mode. In this + // case, according to the Hout and Wout formula in this page: + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d, + // the ceil_mode will round up on the stride division. The round-up + // yields an extra output element whose window goes out of bounds, and + // PyTorch fills it with implicit zero padding. PyTorch also does not + // count these implicit zero-padding elements in the divisor, and this + // behavior is not controlled by the + // count_include_pad argument. + // But also note that if all strides are 1 there are no fractions to + // round up, hence there is no ceiling rounding and the window will + // not go out of bounds. For this case the divisor is just the + // product of kernel dimensions.
+ // Search for torch.nn.AvgPool2d E2E tests for coverage of these + // conditions. + + bool hasPadding = + !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); + bool allStridesUnitary = + llvm::all_of(strideInts, [](int64_t s) { return s == 1; }); + + return (!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary); } template -std::optional ConvertAtenAvgPoolOp:: - createAvgPoolValueCountIncludePadFalseCase( - bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, +LogicalResult ConvertAtenAvgPoolOp:: + createAveragePoolValueWithClampedDivisor( + bool ceilMode, bool countIncludePad, OpTy op, + typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, + Value self, Value sumPool, Value outputTensor, Type resultType, + SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVector &indexingMapsAvg, @@ -1069,11 +1180,6 @@ std::optional ConvertAtenAvgPoolOp:: constexpr int avgPoolDims = getAvgPoolNumOfDims(); - bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); - if (countIncludePad || noPadding) { - // These cases are not handled here. - return std::nullopt; - } if (avgPoolDims < 1) { return rewriter.notifyMatchFailure( op, "Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, " @@ -1082,8 +1188,8 @@ std::optional ConvertAtenAvgPoolOp:: Type resultElementType = cast(resultType).getElementType(); - PoolSizeCalculator poolSizeCalculator(self, sumPool, rewriter, - loc); + PoolSizeCalculator poolSizeCalculator( + self, sumPool, countIncludePad, ceilMode, rewriter, loc); // AtenAvgPool2/3dOp has an optional divisor_override // attribute while AtenAvgPool1dOp does not. @@ -1104,7 +1210,7 @@ std::optional ConvertAtenAvgPoolOp:: [&](OpBuilder &b, Location loc, ValueRange args) { if (!poolSize) { poolSize = poolSizeCalculator.getPoolSize( - b, kernelSizeIntValues, strideInts, paddingInts); + b, kernelDimSizes, strideInts, paddingInts); } Value divisor = convertScalarToDtype(b, loc, poolSize, resultElementType); @@ -1122,21 +1228,21 @@ std::optional ConvertAtenAvgPoolOp:: template LogicalResult ConvertAtenAvgPoolOp:: - createAvgPoolValueCountIncludePadTrueCase( + createAveragePoolValueWithRegularDivisor( OpTy op, typename OpTy::Adaptor &adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg) { Location loc = op->getLoc(); Type resultElementType = cast(resultType).getElementType(); - Value divisor = kernelSizeIntValues[0]; - for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) { - divisor = rewriter.createOrFold(loc, divisor, - kernelSizeIntValues[i]); + Value divisor = kernelDimSizes[0]; + for (uint32_t i = 1; i < kernelDimSizes.size(); ++i) { + divisor = + rewriter.createOrFold(loc, divisor, kernelDimSizes[i]); } // Only average pooling 2D/3D have optional divisor override. 
 if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 42d7e01f9468..871e73ee9fd7 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -650,6 +650,13 @@
     "Aten_EmbeddingBagExample_basic",
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
+    "AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic",
+    "AvgPool2dCeilPadNonUnitaryStrides_basic",
+    "AvgPool2dCeilNoPadStridedIncludePadding_basic",
+    "AvgPool2dCeilPaddingStridedIncludePadding_basic",
+    "AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
+    "AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
+    "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "BernoulliTensorModule_basic",
     "BincountMinlengthModule_basic",
@@ -2791,6 +2798,10 @@
     "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
     "AvgPool2dSingleIntTupleParamsModule_basic",
     "AvgPool2dWithoutPadModule_basic",
+    "AvgPool1dNoPadCeilPadNotIncluded_basic",
+    "AvgPool1dPadCeilPadNotIncluded_basic",
+    "AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic",
+    "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
     "BatchMlpLayerModule_basic",
     "BincountMinlengthModule_basic",
     "BincountModule_basic",
@@ -3533,6 +3544,13 @@
     "AvgPool1dIntModule_basic",
     "AvgPool1dStaticModule_basic",
     "AvgPool2dCeilModeTrueModule_basic",
+    "AvgPool1dNoPadCeilPadNotIncluded_basic",
+    "AvgPool1dPadCeilPadNotIncluded_basic",
+    "AvgPool2dCeilPaddingStridedIncludePadding_basic",
+    "AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic",
+    "AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic",
+    "AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
+    "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "AvgPool2dFloatModule_basic",
     "AvgPool2dIntModule_basic",
@@ -3939,6 +3957,9 @@
     "AtenKthvalueFloat64Module_basic",
     "AtenKthvalueKeepDimModule_basic",
     "AtenKthvalueModule_basic",
+    "AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic",
+    "AvgPool2dCeilNoPadUnitaryStrides_basic",
+    "AvgPool2dCeilPadNonUnitaryStrides_basic",
     "AvgPool2dCountIncludePadFalseStaticModule_basic",
     "AvgPool3dStaticModule_basic",
     "Conv_Transpose1dModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index 4a43b99033c1..ed8ec0faefa4 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -2514,3 +2514,494 @@ def MaxUnpool3dModulePad0_basic(module, tu: TestUtils):
     output, indices = pool(input)
     module.forward(output, indices)
+
+
+class AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa(torch.nn.Module):
+    # This test captures the torch-mlir issue reported here:
+    # https://github.com/llvm/torch-mlir/issues/4079
+
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=[3, 3],
+            stride=[2, 2],
+            padding=[0, 0],
+            ceil_mode=True,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 4, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa())
+def AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 4, low=-1))
+
+
+class 
AvgPool2dCeilNoPadUnitaryStrides(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadUnitaryStrides()) +def AvgPool2dCeilNoPadUnitaryStrides_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPadNonUnitaryStrides(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilPadNonUnitaryStrides()) +def AvgPool2dCeilPadNonUnitaryStrides_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadStridedIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[0, 0], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadStridedIncludePadding()) +def AvgPool2dCeilNoPadStridedIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilNoPadUnitaryStrideIncludePadding() +) +def AvgPool2dCeilNoPadUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse() +) +def AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dFloorNoPadUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return 
self.ap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool2dFloorNoPadUnitaryStrideIncludePadding()
+)
+def AvgPool2dFloorNoPadUnitaryStrideIncludePadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 4, low=-1))
+
+
+class AvgPool2dFloorPaddingUnitaryStrideIncludePadding(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=[3, 3],
+            stride=[1, 1],
+            padding=[1, 1],
+            ceil_mode=False,
+            count_include_pad=True,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 4, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool2dFloorPaddingUnitaryStrideIncludePadding()
+)
+def AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 4, low=-1))
+
+
+class AvgPool2dCeilPaddingUnitaryStrideIncludePadding(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=[3, 3],
+            stride=[1, 1],
+            padding=[1, 1],
+            ceil_mode=True,
+            count_include_pad=True,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 4, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool2dCeilPaddingUnitaryStrideIncludePadding()
+)
+def AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 4, low=-1))
+
+
+class AvgPool2dCeilPaddingStridedIncludePadding(torch.nn.Module):
+    # Note that in this case the kernel window center can land in the
+    # padding. When that happens the explicit padding elements are counted
+    # in the divisor, but the out-of-bounds elements introduced by the
+    # ceiling mode are not (i.e., they are clamped from the divisor count).
+
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=[3, 3],
+            stride=[2, 2],
+            padding=[1, 1],
+            ceil_mode=True,
+            count_include_pad=True,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 4, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool2dCeilPaddingStridedIncludePadding())
+def AvgPool2dCeilPaddingStridedIncludePadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 4, low=-1))
+
+
+class AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded(torch.nn.Module):
+    # Different sizes used for each kernel and stride dimension. No padding.
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=[3, 2],
+            stride=[2, 3],
+            padding=[0, 0],
+            ceil_mode=True,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 3, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded()
+)
+def AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 3, 4, low=-1))
+
+
+class AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded(torch.nn.Module):
+    # Different sizes used for each kernel, stride, and padding dimension.
+
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=[3, 4],
+            stride=[2, 3],
+            padding=[1, 2],
+            ceil_mode=True,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 3, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded()
+)
+def AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 3, 4, low=-1))
+
+
+class AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded(torch.nn.Module):
+    # 3D version of AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded.
+
+    def __init__(self):
+        super().__init__()
+        self.ap3d = torch.nn.AvgPool3d(
+            kernel_size=[3, 2, 4],
+            stride=[3, 2, 5],
+            padding=[0, 0, 0],
+            ceil_mode=True,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 4, 5, 7], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap3d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded()
+)
+def AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 5, 7, low=-1))
+
+
+class AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded(torch.nn.Module):
+    # 3D version of AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded.
+
+    def __init__(self):
+        super().__init__()
+        self.ap3d = torch.nn.AvgPool3d(
+            kernel_size=[3, 4, 7],
+            stride=[2, 3, 4],
+            padding=[1, 2, 3],
+            ceil_mode=True,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 3, 4, 7], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap3d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded()
+)
+def AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 3, 4, 7, low=-1))
+
+
+class AvgPool1dNoPadCeilPadNotIncluded(torch.nn.Module):
+    # 1D version of AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded.
+
+    def __init__(self):
+        super().__init__()
+        self.ap1d = torch.nn.AvgPool1d(
+            kernel_size=[2],
+            stride=[2],
+            padding=[0],
+            ceil_mode=True,
+            count_include_pad=False,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 5], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap1d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool1dNoPadCeilPadNotIncluded())
+def AvgPool1dNoPadCeilPadNotIncluded_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 5, low=-1))
+
+
+class AvgPool1dPadCeilPadNotIncluded(torch.nn.Module):
+    # 1D version of AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded.
+
+    def __init__(self):
+        super().__init__()
+        self.ap1d = torch.nn.AvgPool1d(
+            kernel_size=[2],
+            stride=[2],
+            padding=[1],
+            ceil_mode=True,
+            count_include_pad=False,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap1d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool1dPadCeilPadNotIncluded())
+def AvgPool1dPadCeilPadNotIncluded_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 3, low=-1))
diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir
index c065e624efa9..91043b83728a 100644
--- a/test/Conversion/TorchToLinalg/pooling.mlir
+++ b/test/Conversion/TorchToLinalg/pooling.mlir
@@ -126,7 +126,7 @@ func.func @forward_avg_pool2d_countincludepad_false(%arg0: !torch.vtensor<[1,3,6
   // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32>
   // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>)
   // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
-  // CHECK-COUNT-4: arith.minsi
+  // CHECK-COUNT-1: arith.minsi
   // CHECK-COUNT-1: arith.divf
   // CHECK: linalg.yield %[[TMP1:.*]] : f32
   // CHECK-NEXT: } -> tensor<1x3x61x27xf32>
@@ -179,7 +179,7 @@ func.func @forward_avg_pool3dd_countincludepad_false(%arg0: !torch.vtensor<[1,3,
   // CHECK: linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32>
   // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>)
   // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
-  // CHECK-COUNT-6: arith.minsi
+  // CHECK-COUNT-3: arith.minsi
   // CHECK-COUNT-1: arith.divf
   // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
   // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32>
@@ -221,7 +221,7 @@ func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512
   // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32>
   // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32>
   // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
-  // CHECK-COUNT-2: arith.minsi
+  // CHECK-COUNT-1: arith.minsi
   // CHECK-COUNT-1: arith.divf
   // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
   // CHECK-NEXT: } -> tensor<1x512x12xf32>
@@ -233,3 +233,31 @@ func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512
   %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32>
   return %3 : !torch.vtensor<[1,512,12],f32>
 }
+
+// CHECK-LABEL: func @forward_avgpool_2d_ceil
+func.func 
@forward_avgpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> {
+  // CHECK: %[[POOL_OUT:.*]] = linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%[[PADDED_IN:.*]], %[[KERNEL_IN:.*]] : tensor<1x1x6x6xf32>, tensor<3x3xf32>) outs(%[[OUT1:.*]] : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL_OUT]] : tensor<1x1x2x2xf32>) outs(%[[GEN_OUT:.*]] : tensor<1x1x2x2xf32>) {
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-COUNT-3: arith.muli
+  // CHECK-COUNT-1: arith.sitofp
+  // CHECK-COUNT-1: arith.divf
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x1x2x2xf32>
+  %int3 = torch.constant.int 3
+  %int3_0 = torch.constant.int 3
+  %int0 = torch.constant.int 0
+  %int0_1 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int2_2 = torch.constant.int 2
+  %0 = torch.prim.ListConstruct %int3, %int3_0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int0, %int0_1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int2, %int2_2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %true = torch.constant.bool true
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %3 = torch.aten.avg_pool2d %arg0, %0, %2, %1, %true, %false, %none : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,2,2],f32>
+  return %3 : !torch.vtensor<[1,1,2,2],f32>
+}

From 0b8d27e7863bd882385de63dcce0b33a3f2a3b43 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Fri, 2 May 2025 16:41:04 -0400
Subject: [PATCH 7/8] Addressing round 2 of Vivek's feedback.
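
One of the renames below (dropping the IreeSwa suffix) touches the test for
issue #4079. As a quick eager-mode sanity check of what that test expects
(illustrative only, not part of the patch), an all-ones input makes the
clamped divisor visible: every output equals exactly 1.0 only if each window
is divided by its in-bounds element count rather than by the full kernel
size of 9.

    import torch

    pool = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=0,
                              ceil_mode=True, count_include_pad=False)
    print(pool(torch.ones(1, 1, 4, 4)))
    # tensor([[[[1., 1.], [1., 1.]]]]): the 3x3, 3x2, 2x3, and 2x2 windows
    # were divided by 9, 6, 6, and 4 respectively.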
---
 lib/Conversion/TorchToLinalg/Pooling.cpp                 | 8 ++++----
 .../pt1/python/torch_mlir_e2e_test/test_suite/pooling.py | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index ae90cea90fff..5635dd9f0524 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -1126,11 +1126,11 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, Dim>::matchAndRewrite(
         ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool,
         outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts,
         indexingMapsAvg, iteratorTypesAvg);
-  } else {
-    return createAveragePoolValueWithRegularDivisor(
-        op, adaptor, rewriter, self, sumPool, outputTensor, resultType,
-        kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg);
   }
+
+  return createAveragePoolValueWithRegularDivisor(
+      op, adaptor, rewriter, self, sumPool, outputTensor, resultType,
+      kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg);
 }

 template <typename OpTy, int Dim>
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index ed8ec0faefa4..b27bef3ab1f2 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -2516,7 +2516,7 @@ def MaxUnpool3dModulePad0_basic(module, tu: TestUtils):
     module.forward(output, indices)
 
 
-class AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa(torch.nn.Module):
+class AvgPool2dCeilNoPadNonUnitaryStrides(torch.nn.Module):
     # This test captures the torch-mlir issue reported here:
     # https://github.com/llvm/torch-mlir/issues/4079
 
@@ -2542,8 +2542,8 @@ def forward(self, x):
     return self.ap2d(x)
 
 
-@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa())
-def AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadNonUnitaryStrides())
+def AvgPool2dCeilNoPadNonUnitaryStrides_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 4, 4, low=-1))

From 2b24438a112afeeaac2958c2df02057d74ce225a Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Fri, 2 May 2025 21:32:36 -0400
Subject: [PATCH 8/8] Bring back the PyTorch-based average pooling divisor
 computation after a couple of corrections.
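
The restored computation follows the ATen reference linked from the code
comments (AvgPoolKernel.cpp). For reviewers, the same arithmetic in plain
Python (an illustrative sketch, not part of the patch; pool_size_1d is a
hypothetical name) looks like this, where oh is the output index and H, kH,
sH, and padH are the input size, kernel size, stride, and padding of one
spatial dimension:

    def pool_size_1d(oh, H, kH, sH, padH, count_include_pad):
        ih0 = oh * sH - padH              # window start; may be negative
        ih1 = min(ih0 + kH, H + padH)     # window end, clipped to padded input
        if count_include_pad:
            return ih1 - ih0              # explicit padding counts, OOB does not
        return min(ih1, H) - max(ih0, 0)  # only in-bounds elements count

    # H=4, kH=3, sH=2, padH=0; ceil_mode output index oh=1:
    assert pool_size_1d(1, 4, 3, 2, 0, False) == 2
    assert pool_size_1d(1, 4, 3, 2, 0, True) == 2

The IDim0/IDim1 values built with arith ops in the diff below mirror ih0/ih1,
including the clamping order.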
---
 lib/Conversion/TorchToLinalg/Pooling.cpp | 131 +++++++----------------
 1 file changed, 36 insertions(+), 95 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 5635dd9f0524..04a4cad30224 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -857,8 +857,7 @@ namespace {
 template <int NumOfDims> class PoolSizeCalculator {
 public:
   PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad,
-                     bool ceilMode, ConversionPatternRewriter &rewriter,
-                     Location loc);
+                     ConversionPatternRewriter &rewriter, Location loc);

   // The algorithm for computing the divisor with
   // count_include_pad equal is mainly based on pytorch
@@ -876,16 +875,15 @@ template <int NumOfDims> class PoolSizeCalculator {
   Value InputSpatialDimSizes[NumOfDims];
   Location location;
   bool isCountIncludePad;
-  bool isCeilMode;
 };
 } // namespace

 template <int NumOfDims>
 PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
-    Value self, Value sumPool, bool countIncludePad, bool ceilMode,
+    Value self, Value sumPool, bool countIncludePad,
     ConversionPatternRewriter &rewriter, Location loc)
-    : location(loc), isCountIncludePad(countIncludePad), isCeilMode(ceilMode) {
+    : location(loc), isCountIncludePad(countIncludePad) {
   auto selfType = cast<RankedTensorType>(self.getType());
   const int64_t selfRank = selfType.getRank();
   RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
@@ -910,40 +908,14 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(

   Value cstZero =
       b.createOrFold<arith::ConstantOp>(location, b.getI64IntegerAttr(0));
-  Value cstOne =
-      b.createOrFold<arith::ConstantOp>(location, b.getI64IntegerAttr(1));
-  Value cstTwo =
-      b.createOrFold<arith::ConstantOp>(location, b.getI64IntegerAttr(2));

   for (int i = 0; i < NumOfDims; ++i) {
-    // The following code computes the clamped kernel size used to compute
-    // the divisor of the average pooling operator. Here is the formula that
-    // it represents:
-    //
-    // indexStartOffset = ceil((kernelSize - 1)/2) - padding
-    //
-    // clampedKernelSize =
-    //     min(outIntIndex * stride + indexStartOffset +
-    //             floor((kernelSize - 1)/2) + 1,
-    //         InputSpatialDimSize + padding) -
-    //     max(outIntIndex * stride + indexStartOffset -
-    //             ceil((kernelSize - 1)/2),
-    //         -padding)
-    //
-    // The outIntIndex is the current iteration value coming from the
-    // linalg.generic op and it represents the center of the kernel window.
-    // The padding above becomes zero if count_include_pad is false.
-    // The kernelSize - 1 is used to subtract the center element of the kernel
-    // from the kernel size before dividing by two. Note that PyTorch even
-    // kernel dimensions are biased to the lower side of the dimension. Hence
-    // the lower length uses ceiling. While the upper length uses floor.
-    //
-    // If count_include_pad is true, in most cases the divisor is just the
-    // product of kernel dimensions. But we still need this logic for the
-    // case in which the ceiling mode is true since the kernel window
-    // center can go into the padding outside of the input tensor. This
-    // introduces an implicit padding that is not controlled by the
-    // count_include_pad parameter. See the
-    // AvgPool2dCeilPaddingStridedIncludePadding E2E test for details.
+    // See the link below for the PyTorch implementation where this is
+    // derived from:
+    // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
+    // Dim below stands for spatial dimension. 
Prior to the February 2025
+    // change, these variables used "height" and "width" (or "h" and "w")
+    // in these intermediate variable names instead of "Dim".

     // The average pool properties of kernel size, strides, and padding are
     // stored in the reverse order of the input tensor dimensions. The
@@ -951,68 +923,37 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
     // corresponds to the current spatial dimension.
     int avgPoolPropIdx = NumOfDims - i - 1;

-    Value padding = b.createOrFold<arith::ConstantOp>(
-        location, b.getI64IntegerAttr(paddingInts[avgPoolPropIdx]));
-    Value InputSpatialDimSize =
-        castIndexToInt64(b, location, InputSpatialDimSizes[i]);
-    // Subtract center element from kernel size before division by two.
-    Value kernelSizeMinusOne = b.createOrFold<arith::SubIOp>(
-        location, kernelDimSizes[avgPoolPropIdx], cstOne);
-    // PyTorch even kernel dimensions are biased to the lower side of the
-    // dimension. Hence the lower length uses ceiling.
-    Value kernelLowerLength = b.createOrFold<arith::CeilDivSIOp>(
-        location, kernelSizeMinusOne, cstTwo);
-    // While the upper length uses floor.
-    Value kernelUpperLength = b.createOrFold<arith::FloorDivSIOp>(
-        location, kernelSizeMinusOne, cstTwo);
-
-    // The more padding there is, the closer we can read to the lower bound
-    // of the input tensor.
-    Value indexStartOffset =
-        b.createOrFold<arith::SubIOp>(location, kernelLowerLength, padding);
-
-    Value outIndex =
+    Value IndexODim =
         b.create<linalg::IndexOp>(location,
                                   /*value=*/SumPoolTypeDimIndex[i]);
-
-    Value outIntIndex = castIndexToInt64(b, location, outIndex);
-
-    Value stride = b.createOrFold<arith::ConstantOp>(
+    Value ODim = castIndexToInt64(b, location, IndexODim);
+    Value DDim = b.createOrFold<arith::ConstantOp>(
         location, b.getI64IntegerAttr(strideInts[avgPoolPropIdx]));
-
-    Value indexStrided = b.createOrFold<arith::AddIOp>(
-        location, b.createOrFold<arith::MulIOp>(location, outIntIndex, stride),
-        indexStartOffset);
-
-    Value inputUpperBound = isCountIncludePad
-                                ? b.createOrFold<arith::AddIOp>(
-                                      location, InputSpatialDimSize, padding)
-                                : InputSpatialDimSize;
-
-    Value inputLowerBound =
-        isCountIncludePad
-            ? b.createOrFold<arith::SubIOp>(location, cstZero, padding)
-            : cstZero;
-
-    Value upperBoundMinusOne = b.createOrFold<arith::AddIOp>(
-        location, indexStrided, kernelUpperLength);
-    Value upperBound =
-        b.createOrFold<arith::AddIOp>(location, upperBoundMinusOne, cstOne);
-    Value upperBoundClamped =
-        b.createOrFold<arith::MinSIOp>(location, upperBound, inputUpperBound);
-
-    Value lowerBound = b.createOrFold<arith::SubIOp>(location, indexStrided,
-                                                     kernelLowerLength);
-    Value lowerBoundClamped =
-        b.createOrFold<arith::MaxSIOp>(location, lowerBound, inputLowerBound);
-    Value clampedKernelSize = b.createOrFold<arith::SubIOp>(
-        location, upperBoundClamped, lowerBoundClamped);
-
+    Value PadDim = b.createOrFold<arith::ConstantOp>(
+        location, b.getI64IntegerAttr(paddingInts[avgPoolPropIdx]));
+    // IDim0 and IDim1 correspond to ih0/iw0 and ih1/iw1 in the linked
+    // PyTorch kernel: the unclamped start and end of the pooling window.
+    Value ODimDDim = b.createOrFold<arith::MulIOp>(location, ODim, DDim);
+    Value IDim0 = b.createOrFold<arith::SubIOp>(location, ODimDDim, PadDim);
+    Value IDim = castIndexToInt64(b, location, InputSpatialDimSizes[i]);
+    Value IDim0KDim = b.createOrFold<arith::AddIOp>(
+        location, IDim0, kernelDimSizes[avgPoolPropIdx]);
+    Value IDimPadDim = b.createOrFold<arith::AddIOp>(location, IDim, PadDim);
+    Value IDim1 =
+        b.createOrFold<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
+
+    Value IDim0Clamped =
+        b.createOrFold<arith::MaxSIOp>(location, IDim0, cstZero);
+    Value IDim1Clamped = b.createOrFold<arith::MinSIOp>(location, IDim1, IDim);
+    Value IDim1_IDim0_Clamped =
+        b.createOrFold<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
+
+    Value poolSizeDim =
+        !isCountIncludePad
+            ? 
IDim1_IDim0_Clamped
+            : b.createOrFold<arith::SubIOp>(location, IDim1, IDim0);
     if (i == 0) {
-      poolSize = clampedKernelSize;
+      poolSize = poolSizeDim;
     } else {
-      poolSize =
-          b.createOrFold<arith::MulIOp>(location, poolSize, clampedKernelSize);
+      poolSize = b.createOrFold<arith::MulIOp>(location, poolSize, poolSizeDim);
     }
   }
   return poolSize;
 }
@@ -1190,7 +1131,7 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, Dim>::
   Type resultElementType = cast<RankedTensorType>(resultType).getElementType();

   PoolSizeCalculator<avgPoolDims> poolSizeCalculator(
-      self, sumPool, countIncludePad, ceilMode, rewriter, loc);
+      self, sumPool, countIncludePad, rewriter, loc);

   // AtenAvgPool2/3dOp has an optional divisor_override
   // attribute while AtenAvgPool1dOp does not.
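
A closing note on the guard introduced earlier in the series: the predicate
doesAvgPoolDivisorNeedsClamping can be read as the following Python mirror
(illustrative only, not part of the patches; needs_clamping is a hypothetical
name):

    def needs_clamping(ceil_mode, count_include_pad, strides, paddings):
        # Mirrors ConvertAtenAvgPoolOp::doesAvgPoolDivisorNeedsClamping.
        has_padding = any(p != 0 for p in paddings)
        all_strides_unitary = all(s == 1 for s in strides)
        return (not count_include_pad and has_padding) or (
            ceil_mode and not all_strides_unitary)

    assert needs_clamping(True, False, [2, 2], [0, 0])      # issue 4079 case
    assert not needs_clamping(True, False, [1, 1], [0, 0])  # unit strides
    assert needs_clamping(False, False, [1, 1], [1, 1])     # padding omitted
    assert not needs_clamping(False, True, [1, 1], [1, 1])  # padding counted

When the predicate is false the lowering keeps the cheap divisor (the product
of the kernel dimensions); when it is true it emits the per-window clamped
divisor restored in this last patch.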