Skip to content

Commit 55b8b59

Browse files
authored
Merge pull request #172 from Xilinx/matthias.tosa_to_linalg_unsigned_maxpool
TosaToLinalg: Support unsigned tosa.max_pool2d (FXML-4556)
2 parents ac54176 + 312d8e5 commit 55b8b59

File tree

6 files changed

+87
-39
lines changed

6 files changed

+87
-39
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ void populateTosaToLinalgConversionPatterns(TypeConverter &converter,
5252

5353
/// Populates conversion passes from TOSA dialect to Linalg named operations.
5454
void populateTosaToLinalgNamedConversionPatterns(
55-
RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
55+
TypeConverter &converter, RewritePatternSet *patterns,
56+
const TosaToLinalgNamedOptions &options);
57+
58+
void populateTosaToLinalgTypeConversion(TypeConverter &converter);
5659

5760
} // namespace tosa
5861
} // namespace mlir

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2756,3 +2756,37 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
27562756

27572757
// clang-format on
27582758
}
2759+
2760+
void mlir::tosa::populateTosaToLinalgTypeConversion(TypeConverter &converter) {
2761+
converter.addConversion([&](Type type) -> std::optional<Type> {
2762+
if (type.isUnsignedInteger()) {
2763+
return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
2764+
IntegerType::SignednessSemantics::Signless);
2765+
}
2766+
return type;
2767+
});
2768+
converter.addConversion([&](TensorType type) -> std::optional<Type> {
2769+
auto converted = converter.convertType(type.getElementType());
2770+
if (!converted)
2771+
return {};
2772+
return type.clone(converted);
2773+
});
2774+
converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
2775+
ValueRange inputs,
2776+
Location loc) -> std::optional<Value> {
2777+
if (inputs.size() != 1)
2778+
return std::nullopt;
2779+
2780+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
2781+
.getResult(0);
2782+
});
2783+
converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
2784+
ValueRange inputs,
2785+
Location loc) -> std::optional<Value> {
2786+
if (inputs.size() != 1)
2787+
return std::nullopt;
2788+
2789+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
2790+
.getResult(0);
2791+
});
2792+
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -760,17 +760,23 @@ class FullyConnectedConverter
760760
}
761761
};
762762

763-
class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
763+
class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
764764
public:
765-
using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
765+
using OpConversionPattern::OpConversionPattern;
766766

767-
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
768-
PatternRewriter &rewriter) const final {
767+
LogicalResult
768+
matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
769+
ConversionPatternRewriter &rewriter) const final {
769770
Location loc = op.getLoc();
770-
Value input = op.getInput();
771+
Value input = adaptor.getInput();
771772
ShapedType inputTy = cast<ShapedType>(input.getType());
772773

773-
ShapedType resultTy = cast<ShapedType>(op.getType());
774+
bool isUnsigned =
775+
cast<ShapedType>(op.getType()).getElementType().isUnsignedInteger();
776+
ShapedType resultTy =
777+
cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
778+
if (!resultTy)
779+
return rewriter.notifyMatchFailure(op, "failed to convert type");
774780
Type resultETy = inputTy.getElementType();
775781

776782
auto dynamicDimsOr =
@@ -786,7 +792,10 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
786792
resultETy, APFloat::getLargest(
787793
cast<FloatType>(resultETy).getFloatSemantics(), true));
788794

789-
if (isa<IntegerType>(resultETy))
795+
else if (isUnsigned)
796+
initialAttr = rewriter.getIntegerAttr(
797+
resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
798+
else if (isa<IntegerType>(resultETy))
790799
initialAttr = rewriter.getIntegerAttr(
791800
resultETy,
792801
APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
@@ -823,9 +832,15 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
823832
Value fakeWindowDims =
824833
rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);
825834

826-
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
827-
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
828-
filledEmptyTensor, strideAttr, dilationAttr);
835+
if (isUnsigned) {
836+
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
837+
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
838+
filledEmptyTensor, strideAttr, dilationAttr);
839+
} else {
840+
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
841+
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
842+
filledEmptyTensor, strideAttr, dilationAttr);
843+
}
829844
return success();
830845
}
831846
};
@@ -1091,7 +1106,8 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
10911106
} // namespace
10921107

10931108
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
1094-
RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
1109+
TypeConverter &converter, RewritePatternSet *patterns,
1110+
const TosaToLinalgNamedOptions &options) {
10951111
if (options.preferConv2DKernelLayoutHWCF) {
10961112
patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
10971113
linalg::Conv2DNhwcHwcfQOp>>(
@@ -1105,11 +1121,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
11051121
// clang-format off
11061122
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
11071123
DepthwiseConvConverter,
1108-
MaxPool2dConverter,
11091124
AvgPool2dConverter,
11101125
FullyConnectedConverter,
11111126
TransposeConverter
11121127
>(patterns->getContext());
1128+
patterns->add<
1129+
MaxPool2dConverter
1130+
>(converter, patterns->getContext());
11131131
patterns->add<
11141132
MatMulConverter>(patterns->getContext(), options.useMatmulForSingleBatch);
11151133
// clang-format on

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct TosaToLinalgNamed
4747
}
4848

4949
void runOnOperation() override {
50+
TypeConverter converter;
51+
mlir::tosa::populateTosaToLinalgTypeConversion(converter);
52+
5053
RewritePatternSet patterns(&getContext());
5154
ConversionTarget target(getContext());
5255
target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect,
@@ -68,7 +71,8 @@ struct TosaToLinalgNamed
6871
TosaToLinalgNamedOptions options;
6972
options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
7073
options.useMatmulForSingleBatch = useMatmulForSingleBatch;
71-
tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
74+
tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns,
75+
options);
7276
if (failed(applyFullConversion(func, target, std::move(patterns))))
7377
signalPassFailure();
7478
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,31 +46,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
4646

4747
void runOnOperation() override {
4848
TypeConverter converter;
49-
converter.addConversion([&](Type type) -> std::optional<Type> {
50-
if (type.isUnsignedInteger()) {
51-
return IntegerType::get(&getContext(), type.getIntOrFloatBitWidth(),
52-
IntegerType::SignednessSemantics::Signless);
53-
}
54-
return type;
55-
});
56-
converter.addConversion([&](TensorType type) -> std::optional<Type> {
57-
auto converted = converter.convertType(type.getElementType());
58-
if (!converted)
59-
return {};
60-
return type.clone(converted);
61-
});
62-
converter.addConversion(
63-
[&converter](FunctionType ty) -> std::optional<Type> {
64-
SmallVector<Type> inputs;
65-
if (failed(converter.convertTypes(ty.getInputs(), inputs)))
66-
return std::nullopt;
67-
68-
SmallVector<Type> results;
69-
if (failed(converter.convertTypes(ty.getResults(), results)))
70-
return std::nullopt;
71-
72-
return FunctionType::get(ty.getContext(), inputs, results);
73-
});
49+
mlir::tosa::populateTosaToLinalgTypeConversion(converter);
7450

7551
RewritePatternSet patterns(&getContext());
7652
ConversionTarget target(getContext());

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,19 @@ func.func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
199199
return
200200
}
201201

202+
// CHECK-LABEL: @max_pool_ui8
203+
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
204+
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
205+
// CHECK: arith.constant 0
206+
// CHECK: linalg.pooling_nhwc_max_unsigned
207+
// CHECK-SAME: ins({{.*}} : tensor<1x6x34x62xi8>, tensor<3x3xi8>)
208+
// CHECK-SAME: outs({{.*}} : tensor<1x4x32x62xi8>)
209+
// CHECK-SAME: -> tensor<1x4x32x62xi8>
210+
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
211+
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
212+
return %0 : tensor<1x4x32x62xui8>
213+
}
214+
202215
// CHECK-LABEL: @max_pool_i16
203216
func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
204217
// CHECK: arith.constant -32768

0 commit comments

Comments
 (0)