Skip to content

Commit 41f64ca

Browse files
committed
TosaToLinalg: Fix unsigned tosa.clamp
Plumb the TypeConverter into PointwiseConverter, and emit unsigned comparisons when the input type is unsigned.
1 parent b06875a commit 41f64ca

File tree

8 files changed

+83
-37
lines changed

8 files changed

+83
-37
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ void addTosaToLinalgPasses(
4747
void registerTosaToLinalgPipelines();
4848

4949
/// Populates conversion patterns from the TOSA dialect to the Linalg dialect.
50-
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
50+
void populateTosaToLinalgConversionPatterns(TypeConverter &converter,
51+
RewritePatternSet *patterns);
5152

5253
/// Populates conversion patterns from the TOSA dialect to Linalg named operations.
5354
void populateTosaToLinalgNamedConversionPatterns(

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
3737
// Takes the parameters for a clamp and turns it into a series of ops for
3838
// integer inputs.
3939
Value clampIntHelper(Location loc, Value arg, Value min, Value max,
40-
OpBuilder &rewriter);
40+
OpBuilder &rewriter, bool isUnsigned);
4141

4242
// Determines whether the integer value falls within the range of the integer type.
4343
bool validIntegerRange(IntegerType ty, int64_t value);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,9 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
4646
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
4747
}
4848

49-
static Value
50-
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
51-
ArrayRef<Type> resultTypes,
52-
PatternRewriter &rewriter) {
49+
static Value createLinalgBodyCalculationForElementwiseOp(
50+
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51+
ConversionPatternRewriter &rewriter) {
5352
Location loc = op->getLoc();
5453
auto elementTy =
5554
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
@@ -186,7 +185,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
186185
Value max = rewriter.create<arith::ConstantIntOp>(
187186
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
188187
intermediateType);
189-
auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
188+
auto clamp =
189+
clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
190190

191191
// Truncate to the final value.
192192
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
@@ -390,10 +390,15 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
390390
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
391391

392392
if (intTy.isUnsignedInteger()) {
393+
if (intTy.getIntOrFloatBitWidth() > 63) {
394+
(void)rewriter.notifyMatchFailure(
395+
op, "support for 64-bit or larger integers is not implemented");
396+
return {};
397+
}
393398
min = std::max(min, (int64_t)0);
394-
max = std::min(
395-
max,
396-
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
399+
max = std::min(max,
400+
(int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
401+
.getZExtValue());
397402
} else {
398403
min =
399404
std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
@@ -407,7 +412,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
407412
loc, min, intTy.getIntOrFloatBitWidth());
408413
auto maxVal = rewriter.create<arith::ConstantIntOp>(
409414
loc, max, intTy.getIntOrFloatBitWidth());
410-
return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
415+
return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
416+
intTy.isUnsignedInteger());
411417
}
412418

413419
// tosa::SigmoidOp
@@ -615,10 +621,9 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
615621
}
616622

617623
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
618-
Location loc, Operation *operation) {
619-
auto rank =
620-
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
621-
return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
624+
Location loc, ValueRange operands,
625+
int64_t rank) {
626+
return llvm::map_to_vector(operands, [&](Value operand) {
622627
return expandRank(rewriter, loc, operand, rank);
623628
});
624629
}
@@ -843,11 +848,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
843848
}
844849

845850
static LogicalResult
846-
emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
851+
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
847852
Operation *operation, ValueRange operands,
848-
ArrayRef<OpFoldResult> targetShape) {
853+
ArrayRef<OpFoldResult> targetShape,
854+
const TypeConverter &converter) {
849855
// Generate output tensor
850-
auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
856+
auto resultType = cast_or_null<RankedTensorType>(
857+
converter.convertType(operation->getResultTypes().front()));
858+
if (!resultType) {
859+
return rewriter.notifyMatchFailure(operation, "failed to convert type");
860+
}
851861
Value outputTensor = rewriter.create<tensor::EmptyOp>(
852862
loc, targetShape, resultType.getElementType());
853863

@@ -894,8 +904,9 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
894904
}
895905

896906
static LogicalResult
897-
elementwiseMatchAndRewriteHelper(Operation *operation,
898-
PatternRewriter &rewriter) {
907+
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
908+
ConversionPatternRewriter &rewriter,
909+
const TypeConverter &converter) {
899910

900911
// Collect op properties
901912
assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
@@ -908,13 +919,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
908919
// Lower operation
909920
IndexPool indexPool;
910921
auto loc = operation->getLoc();
911-
auto expandedOperands = expandInputRanks(rewriter, loc, operation);
922+
auto rank =
923+
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
924+
auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
912925
auto [targetShape, masterOperands] =
913926
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
914927
auto broadcastOperands = broadcastDynamicDimensions(
915928
rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
916929
return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
917-
targetShape);
930+
targetShape, converter);
918931
}
919932

920933
// Returns the constant initial value for a given reduction operation. The
@@ -1100,13 +1113,16 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
11001113
namespace {
11011114

11021115
template <typename SrcOp>
1103-
class PointwiseConverter : public OpRewritePattern<SrcOp> {
1116+
class PointwiseConverter : public OpConversionPattern<SrcOp> {
11041117
public:
1105-
using OpRewritePattern<SrcOp>::OpRewritePattern;
1118+
using OpConversionPattern<SrcOp>::OpConversionPattern;
1119+
using typename OpConversionPattern<SrcOp>::OpAdaptor;
11061120

1107-
LogicalResult matchAndRewrite(SrcOp op,
1108-
PatternRewriter &rewriter) const final {
1109-
return elementwiseMatchAndRewriteHelper(op, rewriter);
1121+
LogicalResult
1122+
matchAndRewrite(SrcOp op, OpAdaptor operands,
1123+
ConversionPatternRewriter &rewriter) const final {
1124+
return elementwiseMatchAndRewriteHelper(
1125+
op, operands.getOperands(), rewriter, *this->getTypeConverter());
11101126
}
11111127
};
11121128

@@ -1279,7 +1295,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
12791295
loc, nestedBuilder.getI32IntegerAttr(intMax));
12801296

12811297
value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1282-
nestedBuilder);
1298+
nestedBuilder, /*isUnsigned=*/false);
12831299

12841300
if (outIntType.getWidth() < 32) {
12851301
value = nestedBuilder.create<arith::TruncIOp>(
@@ -1643,7 +1659,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16431659

16441660
auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
16451661
val = b.create<arith::AddIOp>(val, offset);
1646-
val = clampIntHelper(loc, val, zeroI32, max, b);
1662+
val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
16471663
return b.create<arith::IndexCastOp>(b.getIndexType(), val);
16481664
};
16491665

@@ -1664,8 +1680,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16641680
Value max, ImplicitLocOpBuilder &b) {
16651681
val0 = in;
16661682
val1 = b.create<arith::AddIOp>(val0, oneVal);
1667-
val0 = clampIntHelper(loc, val0, zeroI32, max, b);
1668-
val1 = clampIntHelper(loc, val1, zeroI32, max, b);
1683+
val0 =
1684+
clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
1685+
val1 =
1686+
clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
16691687
val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
16701688
val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
16711689
};
@@ -2552,7 +2570,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
25522570
} // namespace
25532571

25542572
void mlir::tosa::populateTosaToLinalgConversionPatterns(
2555-
RewritePatternSet *patterns) {
2573+
TypeConverter &converter, RewritePatternSet *patterns) {
25562574

25572575
// We have multiple resize converters to handle degenerate cases.
25582576
patterns->add<GenericResizeConverter>(patterns->getContext(),
@@ -2599,7 +2617,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
25992617
PointwiseConverter<tosa::CeilOp>,
26002618
PointwiseConverter<tosa::FloorOp>,
26012619
PointwiseConverter<tosa::ClampOp>,
2602-
PointwiseConverter<tosa::SigmoidOp>,
2620+
PointwiseConverter<tosa::SigmoidOp>
2621+
>(converter, patterns->getContext());
2622+
2623+
patterns->add<
26032624
IdentityNConverter<tosa::IdentityOp>,
26042625
ReduceConverter<tosa::ReduceAllOp>,
26052626
ReduceConverter<tosa::ReduceAnyOp>,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
10151015
auto max = rewriter.create<arith::ConstantIntOp>(
10161016
loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
10171017
accETy);
1018-
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
1018+
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
1019+
/*isUnsigned=*/false);
10191020

10201021
poolVal = clamp;
10211022
// Convert type.

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,11 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
6363

6464
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
6565

66+
TypeConverter converter;
67+
tosa::populateTosaToLinalgTypeConversion(converter);
68+
6669
FunctionOpInterface func = getOperation();
67-
mlir::tosa::populateTosaToLinalgConversionPatterns(&patterns);
70+
mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
6871
if (failed(applyFullConversion(func, target, std::move(patterns))))
6972
signalPassFailure();
7073
}

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
3838
}
3939

4040
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
41-
OpBuilder &rewriter) {
41+
OpBuilder &rewriter, bool isUnsigned) {
42+
if (isUnsigned) {
43+
auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
44+
return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
45+
}
4246
auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
4347
return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
4448
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,12 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
2727
%2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<*xf32>) -> tensor<10x10xf32>
2828
return %2 : tensor<10x10xf32>
2929
}
30+
31+
// -----
32+
33+
// CHECK-LABEL: @clamp_on_large_int
34+
func.func @clamp_on_large_int(%arg0: tensor<1xui64>) -> tensor<1xui64> {
35+
// expected-error@+1 {{failed to legalize operation 'tosa.clamp'}}
36+
%0 = tosa.clamp %arg0 {min_int = -1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
37+
return %0 : tensor<1xui64>
38+
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ func.func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
606606
// -----
607607

608608
// CHECK-LABEL: @test_simple_i32
609-
func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
609+
func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>) -> () {
610610
// CHECK: linalg.generic
611611
// CHECK: arith.addi
612612
%0 = tosa.add %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
@@ -700,6 +700,13 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
700700
// CHECK-DAG: arith.minsi
701701
%19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
702702

703+
// CHECK: linalg.generic
704+
// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
705+
// CHECK-DAG: %[[UB:.*]] = arith.constant 5 : i32
706+
// CHECK-DAG: arith.maxui %[[LB]],
707+
// CHECK-DAG: arith.minui %[[UB]],
708+
%u19 = tosa.clamp %unsigned {min_int = -1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
709+
703710
// CHECK: linalg.generic
704711
// CHECK: arith.trunci
705712
%20 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xi16>

0 commit comments

Comments
 (0)