-
Notifications
You must be signed in to change notification settings - Fork 14.3k
TosaToTensor: Support reshape on tensors of unsigned integer #91734
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Matthias Gehre (mgehre-amd) ChangesThis adds
I choose to implement the type converter in Full diff: https://github.com/llvm/llvm-project/pull/91734.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 5fd77c8a0211a..d3024c7389b9c 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -18,6 +18,7 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+class TypeConverter;
#define GEN_PASS_DECL_TOSATOLINALG
#define GEN_PASS_DECL_TOSATOLINALGNAMED
@@ -52,6 +53,8 @@ void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
void populateTosaToLinalgNamedConversionPatterns(
RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
+void populateTosaToLinalgTypeConversion(TypeConverter &converter);
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
index 3953c83f3aa10..76a4b1b156336 100644
--- a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
+++ b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
@@ -16,6 +16,7 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+class TypeConverter;
#define GEN_PASS_DECL_TOSATOTENSOR
#include "mlir/Conversion/Passes.h.inc"
@@ -24,7 +25,8 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToTensor();
-void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToTensorConversionPatterns(TypeConverter &converter,
+ RewritePatternSet *patterns);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e6ba6e6bc602d..dcb15012bda88 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2617,3 +2617,37 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
TileConverter>(patterns->getContext());
// clang-format on
}
+
+void mlir::tosa::populateTosaToLinalgTypeConversion(TypeConverter &converter) {
+ converter.addConversion([&](Type type) -> std::optional<Type> {
+ if (type.isUnsignedInteger()) {
+ return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
+ IntegerType::SignednessSemantics::Signless);
+ }
+ return type;
+ });
+ converter.addConversion([&](TensorType type) -> std::optional<Type> {
+ auto converted = converter.convertType(type.getElementType());
+ if (!converted)
+ return {};
+ return type.clone(converted);
+ });
+ converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ if (inputs.size() != 1)
+ return std::nullopt;
+
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+ converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ if (inputs.size() != 1)
+ return std::nullopt;
+
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+}
diff --git a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
index 2870baa20757b..b1e7c9cba1a78 100644
--- a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRTosaToTensor
MLIRIR
MLIRPass
MLIRTosaDialect
+ MLIRTosaToLinalg
MLIRTosaTransforms
MLIRSupport
)
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index cd6da35582469..33f388faf6648 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -225,8 +225,17 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = reshape.getLoc();
- auto resultType = reshape.getResult().getType();
- auto input = reshape.getInput1();
+ auto resultType = cast_if_present<ShapedType>(
+ getTypeConverter()->convertType(reshape.getType()));
+ if (!resultType) {
+ return rewriter.notifyMatchFailure(reshape.getLoc(),
+ "could not convert result type");
+ }
+ auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
+ if (!input) {
+ return rewriter.notifyMatchFailure(reshape.getLoc(),
+ "expected input type to be tensor");
+ }
auto newShape = reshape.getNewShape();
// Infer all intermediate types
@@ -289,12 +298,13 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
}
};
-class PadConverter : public OpRewritePattern<tosa::PadOp> {
+class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
- using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+ using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(tosa::PadOp padOp,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
auto loc = padOp.getLoc();
auto input = padOp.getInput1();
auto padding = padOp.getPadding();
@@ -429,11 +439,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
} // namespace
void mlir::tosa::populateTosaToTensorConversionPatterns(
- RewritePatternSet *patterns) {
- patterns->add<
- ConcatConverter,
- PadConverter,
- ReshapeConverter,
- SliceConverter
- >(patterns->getContext());
+ TypeConverter &converter, RewritePatternSet *patterns) {
+ patterns
+ ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
+ converter, patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 50dc55667fb94..9ae5edcce291e 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -20,6 +20,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <mlir/Conversion/TosaToLinalg/TosaToLinalg.h>
namespace mlir {
#define GEN_PASS_DEF_TOSATOTENSOR
@@ -42,7 +43,10 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<tensor::TensorDialect>();
- mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
+ TypeConverter converter;
+ mlir::tosa::populateTosaToLinalgTypeConversion(converter);
+
+ mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index b8c3d56f21f10..2eddde9a55660 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -405,6 +405,20 @@ func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) ->
// -----
+// CHECK-LABEL: @test_reshape_samerank_unsigned
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>)
+func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> {
+ // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8>
+ // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
+ // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8>
+ // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xui8>) -> tensor<2x3xui8>
+ // CHECK-NEXT: return %[[CAST2]]
+ return %0 : tensor<2x3xui8>
+}
+
+// -----
+
// CHECK-LABEL: func @slice
func.func @slice(%arg0: tensor<6xf32>) ->() {
// CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
|
This is how TosaToLinalg would use the type converter: #91749 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @mgehre-amd
Would it make sense to put the type converter inside the TOSA dialect itself as a utility function ? This would make it broadly useful to pathways that may not use Linalg. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for refactoring this!
) This adds - `mlir::tosa::populateTosaToLinalgTypeConversion` which converts tensors of unsigned integers into tensors of signless integers - modifies the `tosa.reshape` lowering in TosaToTensor to use the type converter correctly I choose to implement the type converter in `mlir/Conversion/TosaToLinalg/TosaToLinalg.h` instead of `mlir/Conversion/TosaToTensor/TosaToTensor.h` because I need the same type converter in the TosaToLinalg lowerings (future PR). Alternatively, I could duplicate the type converter so it exists both in TosaToLinalg and TosaToTensor. Let me know if you prefer that.
This adds
mlir::tosa::populateTosaToLinalgTypeConversion
which converts tensors of unsigned integers into tensors of signless integerstosa.reshape
lowering in TosaToTensor to use the type converter correctlyI choose to implement the type converter in
mlir/Conversion/TosaToLinalg/TosaToLinalg.h
instead ofmlir/Conversion/TosaToTensor/TosaToTensor.h
because I need the same type converter in the TosaToLinalg lowerings (future PR).Alternatively, I could duplicate the type converter so it exists both in TosaToLinalg and TosaToTensor. Let me know if you prefer that.