7 changes: 7 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -106,6 +106,13 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
Type inputElemTy, Type outputElemTy,
ArrayRef<int64_t> weightShape);

// Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that later
// avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as
// {top, bottom, left, right}. Returns the padded tensor value.
Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
Operation *op, Value inputNHWC,
ArrayRef<int64_t> padExtents);

} // namespace tosa
} // namespace mlir
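
Note on why the explicit pad works: TOSA's avg_pool2d averages each window over the number of non-padding elements it covers, which is why `count_include_pad=true` previously had to be rejected. Once the zeros are materialized in the tensor itself and the op runs with `pad = 0`, every window element counts, so the divisor becomes the full kernel size, which is exactly PyTorch's `count_include_pad=true` behaviour. A minimal standalone sketch of that equivalence (toy 1-D values, not taken from the patch):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> in = {1, 2, 3, 4};
  const int kernel = 3, pad = 1;

  // Materialize the zero padding explicitly, as the new lowering does.
  std::vector<float> padded(in.size() + 2 * pad, 0.0f);
  std::copy(in.begin(), in.end(), padded.begin() + pad);

  for (size_t o = 0; o + kernel <= padded.size(); ++o) {
    float sum = 0;
    int valid = 0; // window elements that fall inside the original input
    for (int k = 0; k < kernel; ++k) {
      sum += padded[o + k];
      if (o + k >= (size_t)pad && o + k < pad + in.size())
        ++valid;
    }
    // pad = 0 over the padded tensor: every element counts, divisor = kernel.
    // That matches count_include_pad=true; dividing by `valid` matches false.
    std::printf("window %zu: include_pad=%.3f exclude_pad=%.3f\n", o,
                sum / kernel, sum / valid);
  }
  return 0;
}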

74 changes: 56 additions & 18 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -6075,7 +6075,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
DenseI64ArrayAttr &pad) {
DenseI64ArrayAttr &pad,
SmallVectorImpl<int64_t> *explicitNHWCPad = nullptr) {

RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
if (!inputTy)
@@ -6115,21 +6116,43 @@ static LogicalResult getOutputTypeAndPoolingParameters(

if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answer (SWA) when the `count_include_pad` value is `true.`
//
// Note: We need to check for `count_include_pad` only when the `padding`
// value is non-zero.
// When count_include_pad=true with non-zero padding, we will materialize an
// explicit pad after transposing to NHWC. Track the padding extents and
// zero out the TOSA op padding so the divisor matches the full kernel size.
bool countIncludePad;
if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
(!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad)) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
"`count_include_pad` value should be `False`.");
if (!explicitNHWCPad)
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
"`count_include_pad` value should be `False`.");

// Remember the spatial padding so we can emit an NHWC tosa.pad right
// after the transpose.
explicitNHWCPad->assign(
{paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});

auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
if (ShapedType::isDynamic(dim))
return ShapedType::kDynamic;
return dim + before + after;
};

// Update the logical input type used for shape computations to include
// the extra zeros supplied by the explicit pad.
SmallVector<int64_t> paddedShape(inputTy.getShape().begin(),
inputTy.getShape().end());
// For NCHW shapes, height lives at index rank-2 and width at index rank-1.
paddedShape[inputRank - 2] =
addPad(paddedShape[inputRank - 2], paddingInts[0], paddingInts[0]);
paddedShape[inputRank - 1] =
addPad(paddedShape[inputRank - 1], paddingInts[1], paddingInts[1]);
inputTy = RankedTensorType::get(paddedShape, inputTy.getElementType());

paddingInts.assign(/*Count=*/2, /*Value=*/0);
}
}

@@ -6275,15 +6298,23 @@ class ConvertAtenAvgPool2dOp
}

SmallVector<int64_t, 2> dilationArray{1, 1};
SmallVector<int64_t, 4> explicitNHWCPad;
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
tosa::AvgPool2dOp>(
op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
op, rewriter, self, dilationArray, outputTy, kernel, stride, pad,
&explicitNHWCPad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");

// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, self);
Value transposed =
ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, self);

if (!explicitNHWCPad.empty())
transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
transposed, explicitNHWCPad);

input = transposed;

return success();
}
@@ -6328,16 +6359,23 @@ class ConvertAtenAvgPool1dOp
.getResult();

SmallVector<int64_t, 2> dilationArray{1, 1};
SmallVector<int64_t, 4> explicitNHWCPad;
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
tosa::AvgPool2dOp>(
op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
pad)))
pad, &explicitNHWCPad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");

// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, reshapedSelf);
Value transposed =
ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, reshapedSelf);

if (!explicitNHWCPad.empty())
transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
transposed, explicitNHWCPad);

input = transposed;

return success();
}
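
The padded logical input type computed in getOutputTypeAndPoolingParameters feeds the usual output-size arithmetic: with the zeros folded into the input, a 35x35 spatial extent padded by one on each side becomes 37x37, and a 3x3 kernel with stride 1 and `pad = 0` maps it back to 35x35, matching the aten result type. A small standalone sketch of that bookkeeping (sizes borrowed from the new tests; the helper below is the standard pooling output-size formula, not code from the patch):

#include <cstdio>

// Standard pooling output size for a single spatial dimension.
int pooledSize(int in, int kernel, int stride, int pad) {
  return (in + 2 * pad - kernel) / stride + 1;
}

int main() {
  const int h = 35, w = 35, kernel = 3, stride = 1, pad = 1;

  // New lowering: zeros are supplied by tosa.pad, so the pooling op sees a
  // 37x37 input and runs with pad = 0.
  int paddedH = h + 2 * pad, paddedW = w + 2 * pad;
  std::printf("padded input: %dx%d\n", paddedH, paddedW);          // 37x37
  std::printf("pool output:  %dx%d\n",
              pooledSize(paddedH, kernel, stride, /*pad=*/0),
              pooledSize(paddedW, kernel, stride, /*pad=*/0));      // 35x35
  return 0;
}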
37 changes: 37 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -595,5 +595,42 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
}
}

Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
Operation *op, Value inputNHWC,
ArrayRef<int64_t> padExtents) {
assert(padExtents.size() == 4 && "expected [top, bottom, left, right]");

if (llvm::all_of(padExtents, [](int64_t v) { return v == 0; }))
return inputNHWC;

SmallVector<int64_t, 8> nhwcPadding = {
0, 0, padExtents[0], padExtents[1], padExtents[2], padExtents[3], 0, 0};
Value nhwcPadShape = tosa::getTosaConstShape(rewriter, loc, nhwcPadding);

auto inputTy = cast<RankedTensorType>(inputNHWC.getType());
SmallVector<int64_t, 4> resultShape(inputTy.getShape().begin(),
inputTy.getShape().end());
auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
if (ShapedType::isDynamic(dim))
return ShapedType::kDynamic;
return dim + before + after;
};
resultShape[1] = addPad(resultShape[1], padExtents[0], padExtents[1]);
resultShape[2] = addPad(resultShape[2], padExtents[2], padExtents[3]);

auto resultTy = RankedTensorType::get(resultShape, inputTy.getElementType());

Type elemTy = inputTy.getElementType();
Value padConst;
if (isa<mlir::FloatType>(elemTy)) {
padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
} else {
padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
}

return rewriter.create<tosa::PadOp>(loc, resultTy, inputNHWC, nhwcPadShape,
padConst);
}

} // namespace tosa
} // namespace mlir
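
For reference, `tosa.pad` takes one (before, after) pair per dimension, so the rank-8 shape built above is just the four NHWC dimensions in order with N and C left unpadded, and a dynamic spatial extent must stay dynamic rather than be summed with the pad. A minimal standalone illustration (the sentinel below is a stand-in for `ShapedType::kDynamic`, whose exact value is an MLIR implementation detail):

#include <array>
#include <cstdint>
#include <cstdio>
#include <limits>

// Stand-in for ShapedType::kDynamic; the real constant comes from MLIR.
constexpr int64_t kDyn = std::numeric_limits<int64_t>::min();

int64_t addPad(int64_t dim, int64_t before, int64_t after) {
  return dim == kDyn ? kDyn : dim + before + after;
}

int main() {
  // {top, bottom, left, right}, as emitExplicitZeroPadNHWC expects.
  const std::array<int64_t, 4> ext = {1, 1, 2, 2};
  // One (before, after) pair per NHWC dimension; N and C are never padded.
  const std::array<int64_t, 8> nhwcPadding = {0, 0, ext[0], ext[1],
                                              ext[2], ext[3], 0, 0};
  for (int64_t v : nhwcPadding)
    std::printf("%lld ", (long long)v); // 0 0 1 1 2 2 0 0
  std::printf("\n");

  const std::array<int64_t, 4> shape = {1, 35, kDyn, 192}; // dynamic width
  std::printf("padded H = %lld, padded W stays %s\n",
              (long long)addPad(shape[1], ext[0], ext[1]),
              addPad(shape[2], ext[2], ext[3]) == kDyn ? "dynamic" : "static");
  return 0;
}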
112 changes: 79 additions & 33 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso

// -----

func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false= torch.constant.bool false
%count_include_pad = torch.constant.bool true
%divisor_override = torch.constant.none

%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
return %3 : !torch.vtensor<[1,192,35,35],f32>
}

// -----

func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to

// -----

func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4328,3 +4295,82 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch
%0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
return %0 : !torch.vtensor<[2,3],f16>
}

// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
// CHECK: %[[VAL_7:.*]] = torch.constant.none
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
// CHECK: }
func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
  %false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%divisor_override = torch.constant.none

%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
return %3 : !torch.vtensor<[1,192,35,35],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_10]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
// CHECK: }
func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}