Skip to content

Commit af22e27

Browse files
authored
TosaToTensor: Support reshape on tensors of unsigned integer (#91734)
This adds - `mlir::tosa::populateTosaToLinalgTypeConversion` which converts tensors of unsigned integers into tensors of signless integers - modifies the `tosa.reshape` lowering in TosaToTensor to use the type converter correctly I choose to implement the type converter in `mlir/Conversion/TosaToLinalg/TosaToLinalg.h` instead of `mlir/Conversion/TosaToTensor/TosaToTensor.h` because I need the same type converter in the TosaToLinalg lowerings (future PR). Alternatively, I could duplicate the type converter so it exists both in TosaToLinalg and TosaToTensor. Let me know if you prefer that.
1 parent f284af4 commit af22e27

File tree

7 files changed

+97
-15
lines changed

7 files changed

+97
-15
lines changed

mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Pass/Pass.h"
1717

1818
namespace mlir {
19+
class TypeConverter;
1920

2021
#define GEN_PASS_DECL_TOSATOTENSOR
2122
#include "mlir/Conversion/Passes.h.inc"
@@ -24,7 +25,8 @@ namespace tosa {
2425

2526
std::unique_ptr<Pass> createTosaToTensor();
2627

27-
void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
28+
void populateTosaToTensorConversionPatterns(TypeConverter &converter,
29+
RewritePatternSet *patterns);
2830

2931
} // namespace tosa
3032
} // namespace mlir

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Pass/Pass.h"
1919

2020
namespace mlir {
21+
class TypeConverter;
2122
namespace tosa {
2223

2324
#define GEN_PASS_DECL
@@ -38,6 +39,8 @@ void populateTosaConstantReduction(MLIRContext *ctx,
3839
RewritePatternSet &patterns,
3940
bool aggressiveReduceConstant);
4041

42+
void populateTosaTypeConversion(TypeConverter &converter);
43+
4144
std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
4245
std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass(
4346
const TosaLayerwiseConstantFoldPassOptions &options);

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,17 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
224224
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
225225
ConversionPatternRewriter &rewriter) const final {
226226
auto loc = reshape.getLoc();
227-
auto resultType = reshape.getResult().getType();
228-
auto input = reshape.getInput1();
227+
auto resultType = cast_if_present<ShapedType>(
228+
getTypeConverter()->convertType(reshape.getType()));
229+
if (!resultType) {
230+
return rewriter.notifyMatchFailure(reshape.getLoc(),
231+
"could not convert result type");
232+
}
233+
auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
234+
if (!input) {
235+
return rewriter.notifyMatchFailure(reshape.getLoc(),
236+
"expected input type to be tensor");
237+
}
229238
auto newShape = reshape.getNewShape();
230239

231240
// Infer all intermediate types
@@ -288,12 +297,13 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
288297
}
289298
};
290299

291-
class PadConverter : public OpRewritePattern<tosa::PadOp> {
300+
class PadConverter : public OpConversionPattern<tosa::PadOp> {
292301
public:
293-
using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
302+
using OpConversionPattern::OpConversionPattern;
294303

295-
LogicalResult matchAndRewrite(tosa::PadOp padOp,
296-
PatternRewriter &rewriter) const final {
304+
LogicalResult
305+
matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
306+
ConversionPatternRewriter &rewriter) const final {
297307
auto loc = padOp.getLoc();
298308
auto input = padOp.getInput1();
299309
auto padding = padOp.getPadding();
@@ -428,11 +438,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
428438
} // namespace
429439

430440
void mlir::tosa::populateTosaToTensorConversionPatterns(
431-
RewritePatternSet *patterns) {
432-
patterns->add<
433-
ConcatConverter,
434-
PadConverter,
435-
ReshapeConverter,
436-
SliceConverter
437-
>(patterns->getContext());
441+
TypeConverter &converter, RewritePatternSet *patterns) {
442+
patterns
443+
->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
444+
converter, patterns->getContext());
438445
}

mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
4242
target.addLegalDialect<arith::ArithDialect>();
4343
target.addLegalDialect<tensor::TensorDialect>();
4444

45-
mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
45+
TypeConverter converter;
46+
mlir::tosa::populateTosaTypeConversion(converter);
47+
48+
mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
4649

4750
if (failed(applyPartialConversion(getOperation(), target,
4851
std::move(patterns))))

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
77
TosaLayerwiseConstantFoldPass.cpp
88
TosaMakeBroadcastable.cpp
99
TosaOptionalDecompositions.cpp
10+
TosaTypeConverters.cpp
1011
TosaValidation.cpp
1112

1213
ADDITIONAL_HEADER_DIRS
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
2+
//===- TosaTypeConverters.cpp ---------------------------------------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// Type converters for lowering TOSA to linalg/arith.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
16+
#include "mlir/Transforms/DialectConversion.h"
17+
18+
using namespace mlir;
19+
20+
void mlir::tosa::populateTosaTypeConversion(TypeConverter &converter) {
21+
converter.addConversion([&](Type type) -> std::optional<Type> {
22+
if (type.isUnsignedInteger()) {
23+
return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
24+
IntegerType::SignednessSemantics::Signless);
25+
}
26+
return type;
27+
});
28+
converter.addConversion([&](TensorType type) -> std::optional<Type> {
29+
auto converted = converter.convertType(type.getElementType());
30+
if (!converted)
31+
return {};
32+
return type.clone(converted);
33+
});
34+
converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
35+
ValueRange inputs,
36+
Location loc) -> std::optional<Value> {
37+
if (inputs.size() != 1)
38+
return std::nullopt;
39+
40+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
41+
.getResult(0);
42+
});
43+
converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
44+
ValueRange inputs,
45+
Location loc) -> std::optional<Value> {
46+
if (inputs.size() != 1)
47+
return std::nullopt;
48+
49+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
50+
.getResult(0);
51+
});
52+
}

mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,20 @@ func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) ->
420420

421421
// -----
422422

423+
// CHECK-LABEL: @test_reshape_samerank_unsigned
424+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>)
425+
func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> {
426+
// CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8>
427+
// CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
428+
// CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8>
429+
// CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8
430+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xui8>) -> tensor<2x3xui8>
431+
// CHECK-NEXT: return %[[CAST2]]
432+
return %0 : tensor<2x3xui8>
433+
}
434+
435+
// -----
436+
423437
// CHECK-LABEL: func @slice
424438
func.func @slice(%arg0: tensor<6xf32>) ->() {
425439
// CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]

0 commit comments

Comments
 (0)