From c99c982e2ff39d8c4be2df54af3deee0469ad42b Mon Sep 17 00:00:00 2001
From: Vitalii Shutov
Date: Mon, 14 Jul 2025 18:22:00 +0100
Subject: [PATCH] [TOSA] Add legalization for avg_pool2d

Before this patch, the `avg_pool2d` and `avg_pool1d` legalizations lacked
support for pooling with count_include_pad=True. This patch introduces that
support: when count_include_pad=True and the padding is non-zero, the input
is zero-padded explicitly in NHWC form and the TOSA avg_pool then runs with
pad = 0, so the averaging divisor always equals the full kernel size.

Signed-off-by: Vitalii Shutov
Change-Id: I73fa26a58379e2c021929ade81c983ff91c59667
---
 .../TorchToTosa/TosaLegalizeUtils.h           |   7 ++
 lib/Conversion/TorchToTosa/TorchToTosa.cpp    |  74 +++++++++---
 .../TorchToTosa/TosaLegalizeUtils.cpp         |  37 ++++++
 test/Conversion/TorchToTosa/basic.mlir        | 112 ++++++++++++------
 4 files changed, 179 insertions(+), 51 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index be1ea0c3221a..2b2f123e1ccd 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -106,6 +106,13 @@
 FailureOr<Value> getConvBiasForNoneType(Operation *op, Type inputElemTy,
                                         Type outputElemTy,
                                         ArrayRef<int64_t> weightShape);
 
+// Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that later
+// avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as
+// {top, bottom, left, right}. Returns the padded tensor value.
+Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
+                              Operation *op, Value inputNHWC,
+                              ArrayRef<int64_t> padExtents);
+
 } // namespace tosa
 } // namespace mlir

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 8f89567df6f7..05c1c261c40a 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -6075,7 +6075,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
     AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
-    DenseI64ArrayAttr &pad) {
+    DenseI64ArrayAttr &pad,
+    SmallVectorImpl<int64_t> *explicitNHWCPad = nullptr) {
 
   RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
   if (!inputTy)
@@ -6115,21 +6116,43 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
-    // Currently, we can not represent `count_include_pad` with the existing
-    // TOSA AvgPool2d specification. Without the below check, we produce silent
-    // wrong answer (SWA) when the `count_include_pad` value is `true.`
-    //
-    // Note: We need to check for `count_include_pad` only when the `padding`
-    // value is non-zero.
+    // When count_include_pad=true with non-zero padding, we will materialize an
+    // explicit pad after transposing to NHWC. Track the padding extents and
+    // zero out the TOSA op padding so the divisor matches the full kernel size.
     bool countIncludePad;
     if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
         (!matchPattern(op.getCountIncludePad(),
                        m_TorchConstantBool(&countIncludePad)) ||
          countIncludePad)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `count_include_pad` value, for tosa AvgPool "
-              "`count_include_pad` value should be `False`.");
+      if (!explicitNHWCPad)
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported `count_include_pad` value, for tosa AvgPool "
+                "`count_include_pad` value should be `False`.");
+
+      // Remember the spatial padding so we can emit an NHWC tosa.pad right
+      // after the transpose.
+      explicitNHWCPad->assign(
+          {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});
+
+      auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
+        if (ShapedType::isDynamic(dim))
+          return ShapedType::kDynamic;
+        return dim + before + after;
+      };
+
+      // Update the logical input type used for shape computations to include
+      // the extra zeros supplied by the explicit pad.
+      SmallVector<int64_t> paddedShape(inputTy.getShape().begin(),
+                                       inputTy.getShape().end());
+      // Height stored at rank-2, width at rank-1 for NCHW shapes.
+      paddedShape[inputRank - 2] =
+          addPad(paddedShape[inputRank - 2], paddingInts[0], paddingInts[0]);
+      paddedShape[inputRank - 1] =
+          addPad(paddedShape[inputRank - 1], paddingInts[1], paddingInts[1]);
+      inputTy = RankedTensorType::get(paddedShape, inputTy.getElementType());
+
+      paddingInts.assign(/*Count=*/2, /*Value=*/0);
     }
   }
 
@@ -6275,15 +6298,23 @@ class ConvertAtenAvgPool2dOp
     }
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
+    SmallVector<int64_t, 4> explicitNHWCPad;
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad,
+            &explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
-    // Transpose to xHWC
-    input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, self);
+    Value transposed =
+        ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
+            transposePoolingInputToHwc(op, rewriter, self);
+
+    if (!explicitNHWCPad.empty())
+      transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
+                                                 transposed, explicitNHWCPad);
+
+    input = transposed;
 
     return success();
   }
@@ -6328,16 +6359,23 @@ class ConvertAtenAvgPool1dOp
             .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
+    SmallVector<int64_t, 4> explicitNHWCPad;
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
             op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
-            pad)))
+            pad, &explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
-    // Transpose to xHWC
-    input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
+    Value transposed =
+        ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
+            transposePoolingInputToHwc(op, rewriter, reshapedSelf);
+
+    if (!explicitNHWCPad.empty())
+      transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
+                                                 transposed, explicitNHWCPad);
+
+    input = transposed;
 
     return success();
   }

diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index 727a4ba5d5e5..a21d54f360ce 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -595,5 +595,42 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
   }
 }
 
+Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
+                              Operation *op, Value inputNHWC,
+                              ArrayRef<int64_t> padExtents) {
+  assert(padExtents.size() == 4 && "expected [top, bottom, left, right]");
+
+  if (llvm::all_of(padExtents, [](int64_t v) { return v == 0; }))
+    return inputNHWC;
+
+  SmallVector<int64_t> nhwcPadding = {
+      0, 0, padExtents[0], padExtents[1], padExtents[2], padExtents[3], 0, 0};
+  Value nhwcPadShape = tosa::getTosaConstShape(rewriter, loc, nhwcPadding);
+
+  auto inputTy = cast<RankedTensorType>(inputNHWC.getType());
+  SmallVector<int64_t> resultShape(inputTy.getShape().begin(),
+                                   inputTy.getShape().end());
+  auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
+    if (ShapedType::isDynamic(dim))
+      return ShapedType::kDynamic;
+    return dim + before + after;
+  };
+  resultShape[1] = addPad(resultShape[1], padExtents[0], padExtents[1]);
+  resultShape[2] = addPad(resultShape[2], padExtents[2], padExtents[3]);
+
+  auto resultTy = RankedTensorType::get(resultShape, inputTy.getElementType());
+
+  Type elemTy = inputTy.getElementType();
+  Value padConst;
+  if (isa<FloatType>(elemTy)) {
+    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
+  } else {
+    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
+  }
+
+  return rewriter.create<tosa::PadOp>(loc, resultTy, inputNHWC, nhwcPadShape,
+                                      padConst);
+}
+
 } // namespace tosa
 } // namespace mlir

diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 3d2e85acee4a..255e39405e68 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
 
 // -----
 
-func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false= torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %divisor_override = torch.constant.none
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
 func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
 
 // -----
 
-func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
-  return %3 : !torch.vtensor<[1,512,10],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4328,3 +4295,82 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch
   %0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
   return %0 : !torch.vtensor<[2,3],f16>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_7:.*]] = torch.constant.none
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
+// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %divisor_override = torch.constant.none
+
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
+  return %3 : !torch.vtensor<[1,192,35,35],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_10]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
+  return %3 : !torch.vtensor<[1,512,10],f32>
+}
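
Reviewer note (illustration only, not part of the patch): the lowering relies on a simple identity. Averaging with count_include_pad=True over the original input gives the same result as first materializing the zero padding and then averaging with pad = 0, because the padded zeros contribute nothing to the sum and the divisor is the full kernel size in both cases. The self-contained C++ sketch below checks that identity on a small 1-D example; the helper names and the use of std::vector instead of tensors are mine, not part of the torch-mlir code.

// Standalone sketch: count_include_pad=True pooling versus explicit zero
// padding followed by pooling with pad = 0.
#include <cassert>
#include <cstdio>
#include <vector>

// 1-D average pool with symmetric zero padding where the divisor is always
// the full kernel size (count_include_pad=True behaviour).
static std::vector<float> avgPoolCountIncludePad(const std::vector<float> &in,
                                                 int kernel, int pad) {
  std::vector<float> out;
  int n = static_cast<int>(in.size());
  for (int start = -pad; start + kernel <= n + pad; ++start) {
    float sum = 0.0f;
    for (int k = 0; k < kernel; ++k) {
      int idx = start + k;
      if (idx >= 0 && idx < n)
        sum += in[idx]; // out-of-range taps read the implicit zero padding
    }
    out.push_back(sum / kernel); // divisor is the full kernel size
  }
  return out;
}

// The same pooling expressed the way the lowering does it: materialize the
// zero padding explicitly, then pool with pad = 0.
static std::vector<float> avgPoolExplicitZeroPad(const std::vector<float> &in,
                                                 int kernel, int pad) {
  std::vector<float> padded(pad, 0.0f);
  padded.insert(padded.end(), in.begin(), in.end());
  padded.insert(padded.end(), pad, 0.0f);
  std::vector<float> out;
  int n = static_cast<int>(padded.size());
  for (int start = 0; start + kernel <= n; ++start) {
    float sum = 0.0f;
    for (int k = 0; k < kernel; ++k)
      sum += padded[start + k];
    out.push_back(sum / kernel);
  }
  return out;
}

int main() {
  std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  auto a = avgPoolCountIncludePad(input, /*kernel=*/3, /*pad=*/1);
  auto b = avgPoolExplicitZeroPad(input, /*kernel=*/3, /*pad=*/1);
  assert(a == b); // both paths produce identical results
  for (float v : a)
    std::printf("%g ", v); // prints: 1 2 3 4 3 (edge windows still divide by 3)
  std::printf("\n");
  return 0;
}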