Skip to content

[MLIR] TosaToLinalgNamed: Lower unsigned tosa.max_pool2d #123290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 20, 2025

Conversation

mgehre-amd
Copy link
Contributor

This PR allows to lower unsigned tosa.max_pool2d to linalg.

// CHECK-LABEL: @max_pool_ui8
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
  // CHECK: arith.constant 0
  // CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8>
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
  return %0 : tensor<1x4x32x62xui8>
}

It does this by

  • converting the MaxPool2dConverter from OpRewriterPattern to OpConversion Pattern
  • adjusting the padding value to the the minimum unsigned value when the max_pool is unsigned
  • lowering to linalg.pooling_nhwc_max_unsigned (which uses arith.maxui) when the max_pool is unsigned

This PR allows to lower unsigned `tosa.max_pool2d` to linalg.
```
// CHECK-LABEL: @max_pool_ui8
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
  // CHECK: arith.constant 0
  // CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8>
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
  return %0 : tensor<1x4x32x62xui8>
}
```
It does this by
- converting the MaxPool2dConverter from OpRewriterPattern to OpConversion Pattern
- adjusting the padding value to the the minimum unsigned value when the max_pool is unsigned
- lowering to `linalg.pooling_nhwc_max_unsigned` (which uses `arith.maxui`) when the max_pool is unsigned
@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

Changes

This PR allows to lower unsigned tosa.max_pool2d to linalg.

// CHECK-LABEL: @<!-- -->max_pool_ui8
func.func @<!-- -->max_pool_ui8(%arg0: tensor&lt;1x6x34x62xui8&gt;) -&gt; tensor&lt;1x4x32x62xui8&gt; {
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor&lt;1x6x34x62xui8&gt; to tensor&lt;1x6x34x62xi8&gt;
  // CHECK: arith.constant 0
  // CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor&lt;1x4x32x62xi8&gt;) -&gt; tensor&lt;1x4x32x62xi8&gt;
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor&lt;1x4x32x62xi8&gt; to tensor&lt;1x4x32x62xui8&gt;
  %0 = tosa.max_pool2d %arg0 {pad = array&lt;i64: 0, 0, 0, 0&gt;, kernel = array&lt;i64: 3, 3&gt;, stride = array&lt;i64: 1, 1&gt;} : (tensor&lt;1x6x34x62xui8&gt;) -&gt; tensor&lt;1x4x32x62xui8&gt;
  return %0 : tensor&lt;1x4x32x62xui8&gt;
}

It does this by

  • converting the MaxPool2dConverter from OpRewriterPattern to OpConversion Pattern
  • adjusting the padding value to the the minimum unsigned value when the max_pool is unsigned
  • lowering to linalg.pooling_nhwc_max_unsigned (which uses arith.maxui) when the max_pool is unsigned

Full diff: https://github.com/llvm/llvm-project/pull/123290.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+2-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+36-16)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (+5-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+13)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 1822016fc88fe6..a1eb22eba69877 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -52,7 +52,8 @@ void populateTosaToLinalgConversionPatterns(const TypeConverter &converter,
 
 /// Populates conversion passes from TOSA dialect to Linalg named operations.
 void populateTosaToLinalgNamedConversionPatterns(
-    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
+    const TypeConverter &converter, RewritePatternSet *patterns,
+    const TosaToLinalgNamedOptions &options);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d537aef5791031..b7af37d293ac1c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -695,17 +695,18 @@ class FullyConnectedConverter
   }
 };
 
-class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
+class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
 public:
-  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+  using OpConversionPattern::OpConversionPattern;
 
   // Compute the dynamic output sizes of the maxpool operation.
   static SmallVector<Value>
-  computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
+                            ConversionPatternRewriter &rewriter) {
     TensorType resultTy = op.getType();
     Location loc = op.getLoc();
 
-    TypedValue<TensorType> input = op.getInput();
+    Value input = adaptor.getInput();
     ArrayRef<int64_t> kernel = op.getKernel();
     ArrayRef<int64_t> pad = op.getPad();
     ArrayRef<int64_t> stride = op.getStride();
@@ -744,16 +745,22 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     return dynamicDims;
   }
 
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const final {
+  LogicalResult
+  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    TypedValue<TensorType> input = op.getInput();
-    ShapedType inputTy = input.getType();
+    Value input = adaptor.getInput();
+    ShapedType inputTy = cast<ShapedType>(input.getType());
 
-    ShapedType resultTy = op.getType();
+    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
+    ShapedType resultTy =
+        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
+    if (!resultTy)
+      return rewriter.notifyMatchFailure(op, "failed to convert type");
     Type resultETy = inputTy.getElementType();
 
-    SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
+    SmallVector<Value> dynamicDims =
+        computeDynamicOutputSizes(op, adaptor, rewriter);
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
@@ -762,7 +769,10 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
           resultETy, APFloat::getLargest(
                          cast<FloatType>(resultETy).getFloatSemantics(), true));
 
-    if (isa<IntegerType>(resultETy))
+    else if (isUnsigned)
+      initialAttr = rewriter.getIntegerAttr(
+          resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
+    else if (isa<IntegerType>(resultETy))
       initialAttr = rewriter.getIntegerAttr(
           resultETy,
           APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
@@ -798,9 +808,15 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     Value fakeWindowDims =
         rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);
 
-    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-        filledEmptyTensor, strideAttr, dilationAttr);
+    if (isUnsigned) {
+      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
+          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+          filledEmptyTensor, strideAttr, dilationAttr);
+    } else {
+      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
+          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+          filledEmptyTensor, strideAttr, dilationAttr);
+    }
     return success();
   }
 };
@@ -1070,7 +1086,8 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
-    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
+    const TypeConverter &converter, RewritePatternSet *patterns,
+    const TosaToLinalgNamedOptions &options) {
   if (options.preferConv2DKernelLayoutHWCF) {
     patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                 linalg::Conv2DNhwcHwcfQOp>>(
@@ -1085,10 +1102,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
-      MaxPool2dConverter,
       AvgPool2dConverter,
       FullyConnectedConverter,
       TransposeConverter
   >(patterns->getContext());
+
+  patterns->add<
+      MaxPool2dConverter
+    >(converter, patterns->getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 096969391e51b9..7d943b3779fb02 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -47,6 +47,9 @@ struct TosaToLinalgNamed
   }
 
   void runOnOperation() override {
+    TypeConverter converter;
+    tosa::populateTosaTypeConversion(converter);
+
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect,
@@ -67,7 +70,8 @@ struct TosaToLinalgNamed
     FunctionOpInterface func = getOperation();
     TosaToLinalgNamedOptions options;
     options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
-    tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
+    tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns,
+                                                      options);
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 453a8610e7169a..5eeaebb384e408 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -200,6 +200,19 @@ func.func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
   return
 }
 
+// CHECK-LABEL: @max_pool_ui8
+func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
+  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
+  // CHECK: arith.constant 0
+  // CHECK: linalg.pooling_nhwc_max_unsigned
+  // CHECK-SAME: ins({{.*}} : tensor<1x6x34x62xi8>, tensor<3x3xi8>)
+  // CHECK-SAME: outs({{.*}} : tensor<1x4x32x62xi8>)
+  // CHECK-SAME: -> tensor<1x4x32x62xi8>
+  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
+  return %0 : tensor<1x4x32x62xui8>
+}
+
 // CHECK-LABEL: @max_pool_i16
 func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
   // CHECK: arith.constant -32768

@Jerry-Ge
Copy link
Member

I wonder is unsigned tosa.max_pool2d defined from the spec? https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good on my side and I do understand the need!
@eric-k256 @sjarus could you please have a look as well as not sure on the spec aspect?

@mgehre-amd
Copy link
Contributor Author

I wonder is unsigned tosa.max_pool2d defined from the spec? mlplatform.org/tosa/tosa_spec.html#_max_pool2d

We had previously agreed (in a TOSA community meeting) to allow more data types in MLIR where useful. The TOSA validation pass can be used to reject them if strict adherence to the spec is needed.
Examples are:
#86509
#91734
#91749

@GeorgeARM
Copy link
Contributor

I wonder is unsigned tosa.max_pool2d defined from the spec? mlplatform.org/tosa/tosa_spec.html#_max_pool2d

We had previously agreed (in a TOSA community meeting) to allow more data types in MLIR where useful. The TOSA validation pass can be used to reject them if strict adherence to the spec is needed. Examples are: #86509 #91734 #91749

Makes sense @mgehre-amd and as mentioned above understand the need. Will proceed with approving this.

@mgehre-amd mgehre-amd merged commit 5ce271e into llvm:main Jan 20, 2025
12 checks passed
@mgehre-amd mgehre-amd deleted the matthias.tosa_maxpool_unsigned branch January 20, 2025 12:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants