Skip to content

Commit b06875a

Browse files
committed
TosaToTensor: Support reshape on tensors of unsigned integer
1 parent 45fed80 commit b06875a

File tree

7 files changed

+80
-15
lines changed

7 files changed

+80
-15
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Pass/Pass.h"
1919

2020
namespace mlir {
21+
class TypeConverter;
2122

2223
#define GEN_PASS_DECL_TOSATOLINALG
2324
#define GEN_PASS_DECL_TOSATOLINALGNAMED
@@ -52,6 +53,8 @@ void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
5253
void populateTosaToLinalgNamedConversionPatterns(
5354
RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
5455

56+
void populateTosaToLinalgTypeConversion(TypeConverter &converter);
57+
5558
} // namespace tosa
5659
} // namespace mlir
5760

mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Pass/Pass.h"
1717

1818
namespace mlir {
19+
class TypeConverter;
1920

2021
#define GEN_PASS_DECL_TOSATOTENSOR
2122
#include "mlir/Conversion/Passes.h.inc"
@@ -24,7 +25,8 @@ namespace tosa {
2425

2526
std::unique_ptr<Pass> createTosaToTensor();
2627

27-
void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
28+
void populateTosaToTensorConversionPatterns(TypeConverter &converter,
29+
RewritePatternSet *patterns);
2830

2931
} // namespace tosa
3032
} // namespace mlir

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2617,3 +2617,37 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
26172617
TileConverter>(patterns->getContext());
26182618
// clang-format on
26192619
}
2620+
2621+
void mlir::tosa::populateTosaToLinalgTypeConversion(TypeConverter &converter) {
2622+
converter.addConversion([&](Type type) -> std::optional<Type> {
2623+
if (type.isUnsignedInteger()) {
2624+
return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
2625+
IntegerType::SignednessSemantics::Signless);
2626+
}
2627+
return type;
2628+
});
2629+
converter.addConversion([&](TensorType type) -> std::optional<Type> {
2630+
auto converted = converter.convertType(type.getElementType());
2631+
if (!converted)
2632+
return {};
2633+
return type.clone(converted);
2634+
});
2635+
converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
2636+
ValueRange inputs,
2637+
Location loc) -> std::optional<Value> {
2638+
if (inputs.size() != 1)
2639+
return std::nullopt;
2640+
2641+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
2642+
.getResult(0);
2643+
});
2644+
converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
2645+
ValueRange inputs,
2646+
Location loc) -> std::optional<Value> {
2647+
if (inputs.size() != 1)
2648+
return std::nullopt;
2649+
2650+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
2651+
.getResult(0);
2652+
});
2653+
}

mlir/lib/Conversion/TosaToTensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRTosaToTensor
1515
MLIRIR
1616
MLIRPass
1717
MLIRTosaDialect
18+
MLIRTosaToLinalg
1819
MLIRTosaTransforms
1920
MLIRSupport
2021
)

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,17 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
225225
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
226226
ConversionPatternRewriter &rewriter) const final {
227227
auto loc = reshape.getLoc();
228-
auto resultType = reshape.getResult().getType();
229-
auto input = reshape.getInput1();
228+
auto resultType = cast_if_present<ShapedType>(
229+
getTypeConverter()->convertType(reshape.getType()));
230+
if (!resultType) {
231+
return rewriter.notifyMatchFailure(reshape.getLoc(),
232+
"could not convert result type");
233+
}
234+
auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
235+
if (!input) {
236+
return rewriter.notifyMatchFailure(reshape.getLoc(),
237+
"expected input type to be tensor");
238+
}
230239
auto newShape = reshape.getNewShape();
231240

232241
// Infer all intermediate types
@@ -289,12 +298,13 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
289298
}
290299
};
291300

292-
class PadConverter : public OpRewritePattern<tosa::PadOp> {
301+
class PadConverter : public OpConversionPattern<tosa::PadOp> {
293302
public:
294-
using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
303+
using OpConversionPattern::OpConversionPattern;
295304

296-
LogicalResult matchAndRewrite(tosa::PadOp padOp,
297-
PatternRewriter &rewriter) const final {
305+
LogicalResult
306+
matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
307+
ConversionPatternRewriter &rewriter) const final {
298308
auto loc = padOp.getLoc();
299309
auto input = padOp.getInput1();
300310
auto padding = padOp.getPadding();
@@ -429,11 +439,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
429439
} // namespace
430440

431441
void mlir::tosa::populateTosaToTensorConversionPatterns(
432-
RewritePatternSet *patterns) {
433-
patterns->add<
434-
ConcatConverter,
435-
PadConverter,
436-
ReshapeConverter,
437-
SliceConverter
438-
>(patterns->getContext());
442+
TypeConverter &converter, RewritePatternSet *patterns) {
443+
patterns
444+
->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
445+
converter, patterns->getContext());
439446
}

mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Pass/PassManager.h"
2121
#include "mlir/Transforms/DialectConversion.h"
2222
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23+
#include <mlir/Conversion/TosaToLinalg/TosaToLinalg.h>
2324

2425
namespace mlir {
2526
#define GEN_PASS_DEF_TOSATOTENSOR
@@ -42,7 +43,10 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
4243
target.addLegalDialect<arith::ArithDialect>();
4344
target.addLegalDialect<tensor::TensorDialect>();
4445

45-
mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
46+
TypeConverter converter;
47+
mlir::tosa::populateTosaToLinalgTypeConversion(converter);
48+
49+
mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
4650

4751
if (failed(applyPartialConversion(getOperation(), target,
4852
std::move(patterns))))

mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,20 @@ func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) ->
405405

406406
// -----
407407

408+
// CHECK-LABEL: @test_reshape_samerank_unsigned
409+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>)
410+
func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> {
411+
// CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8>
412+
// CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
413+
// CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8>
414+
// CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8
415+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xui8>) -> tensor<2x3xui8>
416+
// CHECK-NEXT: return %[[CAST2]]
417+
return %0 : tensor<2x3xui8>
418+
}
419+
420+
// -----
421+
408422
// CHECK-LABEL: func @slice
409423
func.func @slice(%arg0: tensor<6xf32>) ->() {
410424
// CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]

0 commit comments

Comments
 (0)