Skip to content

TosaToTensor: Support reshape on tensors of unsigned integer #91734

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 28, 2024

Conversation

mgehre-amd
Copy link
Contributor

This adds

  • mlir::tosa::populateTosaToLinalgTypeConversion which converts tensors of unsigned integers into tensors of signless integers
  • modifies the tosa.reshape lowering in TosaToTensor to use the type converter correctly

I choose to implement the type converter in mlir/Conversion/TosaToLinalg/TosaToLinalg.h instead of mlir/Conversion/TosaToTensor/TosaToTensor.h because I need the same type converter in the TosaToLinalg lowerings (future PR).
Alternatively, I could duplicate the type converter so it exists both in TosaToLinalg and TosaToTensor. Let me know if you prefer that.

@llvmbot
Copy link
Member

llvmbot commented May 10, 2024

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

Changes

This adds

  • mlir::tosa::populateTosaToLinalgTypeConversion which converts tensors of unsigned integers into tensors of signless integers
  • modifies the tosa.reshape lowering in TosaToTensor to use the type converter correctly

I choose to implement the type converter in mlir/Conversion/TosaToLinalg/TosaToLinalg.h instead of mlir/Conversion/TosaToTensor/TosaToTensor.h because I need the same type converter in the TosaToLinalg lowerings (future PR).
Alternatively, I could duplicate the type converter so it exists both in TosaToLinalg and TosaToTensor. Let me know if you prefer that.


Full diff: https://github.com/llvm/llvm-project/pull/91734.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+3)
  • (modified) mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h (+3-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+34)
  • (modified) mlir/lib/Conversion/TosaToTensor/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+20-13)
  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp (+5-1)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+14)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 5fd77c8a0211a..d3024c7389b9c 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -18,6 +18,7 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class TypeConverter;
 
 #define GEN_PASS_DECL_TOSATOLINALG
 #define GEN_PASS_DECL_TOSATOLINALGNAMED
@@ -52,6 +53,8 @@ void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
 void populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
 
+void populateTosaToLinalgTypeConversion(TypeConverter &converter);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
index 3953c83f3aa10..76a4b1b156336 100644
--- a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
+++ b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
@@ -16,6 +16,7 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class TypeConverter;
 
 #define GEN_PASS_DECL_TOSATOTENSOR
 #include "mlir/Conversion/Passes.h.inc"
@@ -24,7 +25,8 @@ namespace tosa {
 
 std::unique_ptr<Pass> createTosaToTensor();
 
-void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToTensorConversionPatterns(TypeConverter &converter,
+                                            RewritePatternSet *patterns);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e6ba6e6bc602d..dcb15012bda88 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2617,3 +2617,37 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       TileConverter>(patterns->getContext());
   // clang-format on
 }
+
+void mlir::tosa::populateTosaToLinalgTypeConversion(TypeConverter &converter) {
+  converter.addConversion([&](Type type) -> std::optional<Type> {
+    if (type.isUnsignedInteger()) {
+      return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
+                              IntegerType::SignednessSemantics::Signless);
+    }
+    return type;
+  });
+  converter.addConversion([&](TensorType type) -> std::optional<Type> {
+    auto converted = converter.convertType(type.getElementType());
+    if (!converted)
+      return {};
+    return type.clone(converted);
+  });
+  converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                                         ValueRange inputs,
+                                         Location loc) -> std::optional<Value> {
+    if (inputs.size() != 1)
+      return std::nullopt;
+
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+  converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                                         ValueRange inputs,
+                                         Location loc) -> std::optional<Value> {
+    if (inputs.size() != 1)
+      return std::nullopt;
+
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+}
diff --git a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
index 2870baa20757b..b1e7c9cba1a78 100644
--- a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRTosaToTensor
   MLIRIR
   MLIRPass
   MLIRTosaDialect
+  MLIRTosaToLinalg
   MLIRTosaTransforms
   MLIRSupport
   )
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index cd6da35582469..33f388faf6648 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -225,8 +225,17 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = reshape.getLoc();
-    auto resultType = reshape.getResult().getType();
-    auto input = reshape.getInput1();
+    auto resultType = cast_if_present<ShapedType>(
+        getTypeConverter()->convertType(reshape.getType()));
+    if (!resultType) {
+      return rewriter.notifyMatchFailure(reshape.getLoc(),
+                                         "could not convert result type");
+    }
+    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
+    if (!input) {
+      return rewriter.notifyMatchFailure(reshape.getLoc(),
+                                         "expected input type to be tensor");
+    }
     auto newShape = reshape.getNewShape();
 
     // Infer all intermediate types
@@ -289,12 +298,13 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
   }
 };
 
-class PadConverter : public OpRewritePattern<tosa::PadOp> {
+class PadConverter : public OpConversionPattern<tosa::PadOp> {
 public:
-  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+  using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(tosa::PadOp padOp,
-                                PatternRewriter &rewriter) const final {
+  LogicalResult
+  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
     auto loc = padOp.getLoc();
     auto input = padOp.getInput1();
     auto padding = padOp.getPadding();
@@ -429,11 +439,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
-    RewritePatternSet *patterns) {
-  patterns->add<
-    ConcatConverter,
-    PadConverter,
-    ReshapeConverter,
-    SliceConverter
-  >(patterns->getContext());
+    TypeConverter &converter, RewritePatternSet *patterns) {
+  patterns
+      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
+          converter, patterns->getContext());
 }
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 50dc55667fb94..9ae5edcce291e 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <mlir/Conversion/TosaToLinalg/TosaToLinalg.h>
 
 namespace mlir {
 #define GEN_PASS_DEF_TOSATOTENSOR
@@ -42,7 +43,10 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
     target.addLegalDialect<arith::ArithDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
 
-    mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
+    TypeConverter converter;
+    mlir::tosa::populateTosaToLinalgTypeConversion(converter);
+
+    mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index b8c3d56f21f10..2eddde9a55660 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -405,6 +405,20 @@ func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) ->
 
 // -----
 
+// CHECK-LABEL: @test_reshape_samerank_unsigned
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>)
+func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> {
+  // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8>
+  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
+  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8>
+  // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xui8>) -> tensor<2x3xui8>
+  // CHECK-NEXT: return %[[CAST2]]
+  return %0 : tensor<2x3xui8>
+}
+
+// -----
+
 // CHECK-LABEL: func @slice
 func.func @slice(%arg0: tensor<6xf32>) ->() {
   // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]

@mgehre-amd
Copy link
Contributor Author

mgehre-amd commented May 10, 2024

This is how TosaToLinalg would use the type converter: #91749

Copy link

@leandron leandron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @mgehre-amd

@sjarus
Copy link
Contributor

sjarus commented May 20, 2024

Would it make sense to put the type converter inside the TOSA dialect itself as a utility function ? This would make it broadly useful to pathways that may not use Linalg.

@mgehre-amd mgehre-amd requested a review from sjarus May 24, 2024 06:54
Copy link
Contributor

@sjarus sjarus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for refactoring this!

@mgehre-amd mgehre-amd merged commit af22e27 into llvm:main May 28, 2024
6 of 8 checks passed
@mgehre-amd mgehre-amd deleted the mgehre.tosa_reshape_unsigned branch May 28, 2024 15:59
vg0204 pushed a commit to vg0204/llvm-project that referenced this pull request May 29, 2024
)

This adds 
- `mlir::tosa::populateTosaToLinalgTypeConversion` which converts
tensors of unsigned integers into tensors of signless integers
- modifies the `tosa.reshape` lowering in TosaToTensor to use the type
converter correctly

I choose to implement the type converter in
`mlir/Conversion/TosaToLinalg/TosaToLinalg.h` instead of
`mlir/Conversion/TosaToTensor/TosaToTensor.h` because I need the same
type converter in the TosaToLinalg lowerings (future PR).
Alternatively, I could duplicate the type converter so it exists both in
TosaToLinalg and TosaToTensor. Let me know if you prefer that.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants