Skip to content

Commit dacdb3a

Browse files
mgehre-amd authored and AlexisPerry committed
TosaToLinalg: Support unsigned tosa.clamp (llvm#91749)
This implements the lowering of tosa.clamp with unsigned operand to linalg. We interpret the `min/max : i64` attributes on `clamp` to be signed. This means that when the operand has type `ui64`, one cannot represent limits across the whole range.
1 parent 17fef60 commit dacdb3a

File tree

7 files changed

+105
-45
lines changed

7 files changed

+105
-45
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ void addTosaToLinalgPasses(
4747
void registerTosaToLinalgPipelines();
4848

4949
/// Populates conversion passes from TOSA dialect to Linalg dialect.
50-
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
50+
void populateTosaToLinalgConversionPatterns(TypeConverter &converter,
51+
RewritePatternSet *patterns);
5152

5253
/// Populates conversion passes from TOSA dialect to Linalg named operations.
5354
void populateTosaToLinalgNamedConversionPatterns(

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
3737
// Takes the parameters for a clamp and turns it into a series of ops for
3838
// integer inputs.
3939
Value clampIntHelper(Location loc, Value arg, Value min, Value max,
40-
OpBuilder &rewriter);
40+
OpBuilder &rewriter, bool isUnsigned);
4141

4242
// Determines whether the integer value falls within the range of integer type.
4343
bool validIntegerRange(IntegerType ty, int64_t value);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 62 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,9 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
4646
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
4747
}
4848

49-
static Value
50-
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
51-
ArrayRef<Type> resultTypes,
52-
PatternRewriter &rewriter) {
49+
static Value createLinalgBodyCalculationForElementwiseOp(
50+
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51+
ConversionPatternRewriter &rewriter) {
5352
Location loc = op->getLoc();
5453
auto elementTy =
5554
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
@@ -186,7 +185,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
186185
Value max = rewriter.create<arith::ConstantIntOp>(
187186
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
188187
intermediateType);
189-
auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
188+
auto clamp =
189+
clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
190190

191191
// Truncate to the final value.
192192
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
@@ -389,25 +389,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
389389
int64_t max =
390390
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
391391

392+
int64_t minRepresentable = std::numeric_limits<int64_t>::min();
393+
int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
392394
if (intTy.isUnsignedInteger()) {
393-
min = std::max(min, (int64_t)0);
394-
max = std::min(
395-
max,
396-
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
397-
} else {
398-
min =
399-
std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
400-
.getSExtValue());
401-
max =
402-
std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
403-
.getSExtValue());
395+
minRepresentable = 0;
396+
if (intTy.getIntOrFloatBitWidth() <= 63) {
397+
maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
398+
.getZExtValue();
399+
}
400+
} else if(intTy.getIntOrFloatBitWidth() <= 64) {
401+
// Ensure that min & max fit into signed n-bit constants.
402+
minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
403+
.getSExtValue();
404+
maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
405+
.getSExtValue();
404406
}
407+
// Ensure that the bounds are representable as n-bit signed/unsigned integers.
408+
min = std::max(min, minRepresentable);
409+
max = std::max(max, minRepresentable);
410+
min = std::min(min, maxRepresentable);
411+
max = std::min(max, maxRepresentable);
405412

406413
auto minVal = rewriter.create<arith::ConstantIntOp>(
407414
loc, min, intTy.getIntOrFloatBitWidth());
408415
auto maxVal = rewriter.create<arith::ConstantIntOp>(
409416
loc, max, intTy.getIntOrFloatBitWidth());
410-
return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
417+
return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
418+
intTy.isUnsignedInteger());
411419
}
412420

413421
// tosa::SigmoidOp
@@ -615,10 +623,9 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
615623
}
616624

617625
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
618-
Location loc, Operation *operation) {
619-
auto rank =
620-
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
621-
return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
626+
Location loc, ValueRange operands,
627+
int64_t rank) {
628+
return llvm::map_to_vector(operands, [&](Value operand) {
622629
return expandRank(rewriter, loc, operand, rank);
623630
});
624631
}
@@ -843,11 +850,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
843850
}
844851

845852
static LogicalResult
846-
emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
853+
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
847854
Operation *operation, ValueRange operands,
848-
ArrayRef<OpFoldResult> targetShape) {
855+
ArrayRef<OpFoldResult> targetShape,
856+
const TypeConverter &converter) {
849857
// Generate output tensor
850-
auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
858+
auto resultType = cast_or_null<RankedTensorType>(
859+
converter.convertType(operation->getResultTypes().front()));
860+
if (!resultType) {
861+
return rewriter.notifyMatchFailure(operation, "failed to convert type");
862+
}
851863
Value outputTensor = rewriter.create<tensor::EmptyOp>(
852864
loc, targetShape, resultType.getElementType());
853865

@@ -894,8 +906,9 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
894906
}
895907

896908
static LogicalResult
897-
elementwiseMatchAndRewriteHelper(Operation *operation,
898-
PatternRewriter &rewriter) {
909+
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
910+
ConversionPatternRewriter &rewriter,
911+
const TypeConverter &converter) {
899912

900913
// Collect op properties
901914
assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
@@ -908,13 +921,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
908921
// Lower operation
909922
IndexPool indexPool;
910923
auto loc = operation->getLoc();
911-
auto expandedOperands = expandInputRanks(rewriter, loc, operation);
924+
auto rank =
925+
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
926+
auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
912927
auto [targetShape, masterOperands] =
913928
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
914929
auto broadcastOperands = broadcastDynamicDimensions(
915930
rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
916931
return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
917-
targetShape);
932+
targetShape, converter);
918933
}
919934

920935
// Returns the constant initial value for a given reduction operation. The
@@ -1100,13 +1115,16 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
11001115
namespace {
11011116

11021117
template <typename SrcOp>
1103-
class PointwiseConverter : public OpRewritePattern<SrcOp> {
1118+
class PointwiseConverter : public OpConversionPattern<SrcOp> {
11041119
public:
1105-
using OpRewritePattern<SrcOp>::OpRewritePattern;
1120+
using OpConversionPattern<SrcOp>::OpConversionPattern;
1121+
using typename OpConversionPattern<SrcOp>::OpAdaptor;
11061122

1107-
LogicalResult matchAndRewrite(SrcOp op,
1108-
PatternRewriter &rewriter) const final {
1109-
return elementwiseMatchAndRewriteHelper(op, rewriter);
1123+
LogicalResult
1124+
matchAndRewrite(SrcOp op, OpAdaptor operands,
1125+
ConversionPatternRewriter &rewriter) const final {
1126+
return elementwiseMatchAndRewriteHelper(
1127+
op, operands.getOperands(), rewriter, *this->getTypeConverter());
11101128
}
11111129
};
11121130

@@ -1279,7 +1297,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
12791297
loc, nestedBuilder.getI32IntegerAttr(intMax));
12801298

12811299
value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1282-
nestedBuilder);
1300+
nestedBuilder, /*isUnsigned=*/false);
12831301

12841302
if (outIntType.getWidth() < 32) {
12851303
value = nestedBuilder.create<arith::TruncIOp>(
@@ -1643,7 +1661,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16431661

16441662
auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
16451663
val = b.create<arith::AddIOp>(val, offset);
1646-
val = clampIntHelper(loc, val, zeroI32, max, b);
1664+
val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
16471665
return b.create<arith::IndexCastOp>(b.getIndexType(), val);
16481666
};
16491667

@@ -1664,8 +1682,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16641682
Value max, ImplicitLocOpBuilder &b) {
16651683
val0 = in;
16661684
val1 = b.create<arith::AddIOp>(val0, oneVal);
1667-
val0 = clampIntHelper(loc, val0, zeroI32, max, b);
1668-
val1 = clampIntHelper(loc, val1, zeroI32, max, b);
1685+
val0 =
1686+
clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
1687+
val1 =
1688+
clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
16691689
val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
16701690
val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
16711691
};
@@ -2555,7 +2575,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
25552575
} // namespace
25562576

25572577
void mlir::tosa::populateTosaToLinalgConversionPatterns(
2558-
RewritePatternSet *patterns) {
2578+
TypeConverter &converter, RewritePatternSet *patterns) {
25592579

25602580
// We have multiple resize converters to handle degenerate cases.
25612581
patterns->add<GenericResizeConverter>(patterns->getContext(),
@@ -2602,7 +2622,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
26022622
PointwiseConverter<tosa::CeilOp>,
26032623
PointwiseConverter<tosa::FloorOp>,
26042624
PointwiseConverter<tosa::ClampOp>,
2605-
PointwiseConverter<tosa::SigmoidOp>,
2625+
PointwiseConverter<tosa::SigmoidOp>
2626+
>(converter, patterns->getContext());
2627+
2628+
patterns->add<
26062629
IdentityNConverter<tosa::IdentityOp>,
26072630
ReduceConverter<tosa::ReduceAllOp>,
26082631
ReduceConverter<tosa::ReduceAnyOp>,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
10151015
auto max = rewriter.create<arith::ConstantIntOp>(
10161016
loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
10171017
accETy);
1018-
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
1018+
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
1019+
/*isUnsigned=*/false);
10191020

10201021
poolVal = clamp;
10211022
// Convert type.

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,11 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
6363

6464
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
6565

66+
TypeConverter converter;
67+
tosa::populateTosaTypeConversion(converter);
68+
6669
FunctionOpInterface func = getOperation();
67-
mlir::tosa::populateTosaToLinalgConversionPatterns(&patterns);
70+
mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
6871
if (failed(applyFullConversion(func, target, std::move(patterns))))
6972
signalPassFailure();
7073
}

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
3838
}
3939

4040
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
41-
OpBuilder &rewriter) {
41+
OpBuilder &rewriter, bool isUnsigned) {
42+
if (isUnsigned) {
43+
auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
44+
return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
45+
}
4246
auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
4347
return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
4448
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ func.func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
606606
// -----
607607

608608
// CHECK-LABEL: @test_simple_i32
609-
func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
609+
func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %unsigned64: tensor<1xui64>) -> () {
610610
// CHECK: linalg.generic
611611
// CHECK: arith.addi
612612
%0 = tosa.add %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
@@ -700,6 +700,34 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
700700
// CHECK-DAG: arith.minsi
701701
%19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
702702

703+
// CHECK: linalg.generic
704+
// CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
705+
// CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
706+
// CHECK-DAG: arith.maxui %[[LB]],
707+
// CHECK-DAG: arith.minui %[[UB]],
708+
%u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
709+
710+
// CHECK: linalg.generic
711+
// CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
712+
// CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
713+
// CHECK-DAG: arith.maxui %[[LB]],
714+
// CHECK-DAG: arith.minui %[[UB]],
715+
%u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
716+
717+
// CHECK: linalg.generic
718+
// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
719+
// CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
720+
// CHECK-DAG: arith.maxui %[[LB]],
721+
// CHECK-DAG: arith.minui %[[UB]],
722+
%u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
723+
724+
// CHECK: linalg.generic
725+
// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
726+
// CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
727+
// CHECK-DAG: arith.maxui %[[LB]],
728+
// CHECK-DAG: arith.minui %[[UB]],
729+
%u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
730+
703731
// CHECK: linalg.generic
704732
// CHECK: arith.trunci
705733
%20 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xi16>

0 commit comments

Comments
 (0)