-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] TosaToLinalgNamed: Lower unsigned tosa.max_pool2d #123290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR allows lowering unsigned `tosa.max_pool2d` to linalg. ``` // CHECK-LABEL: @max_pool_ui8 func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> { // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8> // CHECK: arith.constant 0 // CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8> // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8> %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> return %0 : tensor<1x4x32x62xui8> } ``` It does this by - converting the MaxPool2dConverter from OpRewritePattern to OpConversionPattern - adjusting the padding value to the minimum unsigned value when the max_pool is unsigned - lowering to `linalg.pooling_nhwc_max_unsigned` (which uses `arith.maxui`) when the max_pool is unsigned
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Matthias Gehre (mgehre-amd) Changes: This PR allows lowering unsigned
It does this by
Full diff: https://github.com/llvm/llvm-project/pull/123290.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 1822016fc88fe6..a1eb22eba69877 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -52,7 +52,8 @@ void populateTosaToLinalgConversionPatterns(const TypeConverter &converter,
/// Populates conversion passes from TOSA dialect to Linalg named operations.
void populateTosaToLinalgNamedConversionPatterns(
- RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
+ const TypeConverter &converter, RewritePatternSet *patterns,
+ const TosaToLinalgNamedOptions &options);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d537aef5791031..b7af37d293ac1c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -695,17 +695,18 @@ class FullyConnectedConverter
}
};
-class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
+class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
- using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+ using OpConversionPattern::OpConversionPattern;
// Compute the dynamic output sizes of the maxpool operation.
static SmallVector<Value>
- computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+ computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) {
TensorType resultTy = op.getType();
Location loc = op.getLoc();
- TypedValue<TensorType> input = op.getInput();
+ Value input = adaptor.getInput();
ArrayRef<int64_t> kernel = op.getKernel();
ArrayRef<int64_t> pad = op.getPad();
ArrayRef<int64_t> stride = op.getStride();
@@ -744,16 +745,22 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
return dynamicDims;
}
- LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
- TypedValue<TensorType> input = op.getInput();
- ShapedType inputTy = input.getType();
+ Value input = adaptor.getInput();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
- ShapedType resultTy = op.getType();
+ bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
+ ShapedType resultTy =
+ cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
+ if (!resultTy)
+ return rewriter.notifyMatchFailure(op, "failed to convert type");
Type resultETy = inputTy.getElementType();
- SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
+ SmallVector<Value> dynamicDims =
+ computeDynamicOutputSizes(op, adaptor, rewriter);
// Determine what the initial value needs to be for the max pool op.
TypedAttr initialAttr;
@@ -762,7 +769,10 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
resultETy, APFloat::getLargest(
cast<FloatType>(resultETy).getFloatSemantics(), true));
- if (isa<IntegerType>(resultETy))
+ else if (isUnsigned)
+ initialAttr = rewriter.getIntegerAttr(
+ resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
+ else if (isa<IntegerType>(resultETy))
initialAttr = rewriter.getIntegerAttr(
resultETy,
APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
@@ -798,9 +808,15 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
Value fakeWindowDims =
rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);
- rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
- op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
- filledEmptyTensor, strideAttr, dilationAttr);
+ if (isUnsigned) {
+ rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
+ op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+ filledEmptyTensor, strideAttr, dilationAttr);
+ } else {
+ rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
+ op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+ filledEmptyTensor, strideAttr, dilationAttr);
+ }
return success();
}
};
@@ -1070,7 +1086,8 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
- RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
+ const TypeConverter &converter, RewritePatternSet *patterns,
+ const TosaToLinalgNamedOptions &options) {
if (options.preferConv2DKernelLayoutHWCF) {
patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
linalg::Conv2DNhwcHwcfQOp>>(
@@ -1085,10 +1102,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
DepthwiseConvConverter,
MatMulConverter,
- MaxPool2dConverter,
AvgPool2dConverter,
FullyConnectedConverter,
TransposeConverter
>(patterns->getContext());
+
+ patterns->add<
+ MaxPool2dConverter
+ >(converter, patterns->getContext());
// clang-format on
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 096969391e51b9..7d943b3779fb02 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -47,6 +47,9 @@ struct TosaToLinalgNamed
}
void runOnOperation() override {
+ TypeConverter converter;
+ tosa::populateTosaTypeConversion(converter);
+
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect,
@@ -67,7 +70,8 @@ struct TosaToLinalgNamed
FunctionOpInterface func = getOperation();
TosaToLinalgNamedOptions options;
options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
- tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
+ tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns,
+ options);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 453a8610e7169a..5eeaebb384e408 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -200,6 +200,19 @@ func.func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
return
}
+// CHECK-LABEL: @max_pool_ui8
+func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
+ // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
+ // CHECK: arith.constant 0
+ // CHECK: linalg.pooling_nhwc_max_unsigned
+ // CHECK-SAME: ins({{.*}} : tensor<1x6x34x62xi8>, tensor<3x3xi8>)
+ // CHECK-SAME: outs({{.*}} : tensor<1x4x32x62xi8>)
+ // CHECK-SAME: -> tensor<1x4x32x62xi8>
+ // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
+ %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
+ return %0 : tensor<1x4x32x62xui8>
+}
+
// CHECK-LABEL: @max_pool_i16
func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
// CHECK: arith.constant -32768
|
I wonder if unsigned |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good on my side and I do understand the need!
@eric-k256 @sjarus could you please have a look as well as not sure on the spec aspect?
We had previously agreed (in a TOSA community meeting) to allow more data types in MLIR where useful. The TOSA validation pass can be used to reject them if strict adherence to the spec is needed. |
Makes sense @mgehre-amd and as mentioned above understand the need. Will proceed with approving this. |
This PR allows to lower unsigned
tosa.max_pool2d
to linalg. It does this by
linalg.pooling_nhwc_max_unsigned
(which uses arith.maxui
) when the max_pool is unsigned