TosaToLinalg: Support unsigned tosa.clamp #91749
Merged
Plumb the TypeConverter into PointwiseConverter, and emit unsigned comparisons when the input type is unsigned. (commits 41f64ca to 4b801e0)
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir
Author: Matthias Gehre (mgehre-amd)
Changes: On top of #91734
Full diff: https://github.com/llvm/llvm-project/pull/91749.diff
7 Files Affected:
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 67965e34d8a3d..c84e4f17c38d8 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -47,7 +47,8 @@ void addTosaToLinalgPasses(
void registerTosaToLinalgPipelines();
/// Populates conversion passes from TOSA dialect to Linalg dialect.
-void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgConversionPatterns(TypeConverter &converter,
+ RewritePatternSet *patterns);
/// Populates conversion passes from TOSA dialect to Linalg named operations.
void populateTosaToLinalgNamedConversionPatterns(
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ca59b221d03eb..ceab7d9c628a5 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -37,7 +37,7 @@ Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
// Takes the parameters for a clamp and turns it into a series of ops for
// integer inputs.
Value clampIntHelper(Location loc, Value arg, Value min, Value max,
- OpBuilder &rewriter);
+ OpBuilder &rewriter, bool isUnsigned);
// Determines whether the integer value falls witin the range of integer type.
bool validIntegerRange(IntegerType ty, int64_t value);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8ad8e41414656..1442f2ad72255 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -46,10 +46,9 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
-static Value
-createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
- ArrayRef<Type> resultTypes,
- PatternRewriter &rewriter) {
+static Value createLinalgBodyCalculationForElementwiseOp(
+ Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
+ ConversionPatternRewriter &rewriter) {
Location loc = op->getLoc();
auto elementTy =
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
@@ -186,7 +185,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
Value max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
intermediateType);
- auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
+ auto clamp =
+ clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
// Truncate to the final value.
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
@@ -389,25 +389,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
int64_t max =
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
+ int64_t minRepresentable = std::numeric_limits<int64_t>::min();
+ int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
if (intTy.isUnsignedInteger()) {
- min = std::max(min, (int64_t)0);
- max = std::min(
- max,
- APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
- } else {
- min =
- std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
- max =
- std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
+ minRepresentable = 0;
+ if (intTy.getIntOrFloatBitWidth() <= 63) {
+ maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
+ .getZExtValue();
+ }
+ } else if(intTy.getIntOrFloatBitWidth() <= 64) {
+ // Ensure that min & max fit into signed n-bit constants.
+ minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue();
+ maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue();
}
+ // Ensure that the bounds are representable as n-bit signed/unsigned integers.
+ min = std::max(min, minRepresentable);
+ max = std::max(max, minRepresentable);
+ min = std::min(min, maxRepresentable);
+ max = std::min(max, maxRepresentable);
auto minVal = rewriter.create<arith::ConstantIntOp>(
loc, min, intTy.getIntOrFloatBitWidth());
auto maxVal = rewriter.create<arith::ConstantIntOp>(
loc, max, intTy.getIntOrFloatBitWidth());
- return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
+ return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
+ intTy.isUnsignedInteger());
}
// tosa::SigmoidOp
@@ -615,10 +623,9 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
}
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
- Location loc, Operation *operation) {
- auto rank =
- cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
- return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
+ Location loc, ValueRange operands,
+ int64_t rank) {
+ return llvm::map_to_vector(operands, [&](Value operand) {
return expandRank(rewriter, loc, operand, rank);
});
}
@@ -843,11 +850,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
}
static LogicalResult
-emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
+emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
Operation *operation, ValueRange operands,
- ArrayRef<OpFoldResult> targetShape) {
+ ArrayRef<OpFoldResult> targetShape,
+ const TypeConverter &converter) {
// Generate output tensor
- auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
+ auto resultType = cast_or_null<RankedTensorType>(
+ converter.convertType(operation->getResultTypes().front()));
+ if (!resultType) {
+ return rewriter.notifyMatchFailure(operation, "failed to convert type");
+ }
Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, targetShape, resultType.getElementType());
@@ -894,8 +906,9 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
}
static LogicalResult
-elementwiseMatchAndRewriteHelper(Operation *operation,
- PatternRewriter &rewriter) {
+elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ const TypeConverter &converter) {
// Collect op properties
assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
@@ -908,13 +921,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
// Lower operation
IndexPool indexPool;
auto loc = operation->getLoc();
- auto expandedOperands = expandInputRanks(rewriter, loc, operation);
+ auto rank =
+ cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
+ auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
auto [targetShape, masterOperands] =
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
auto broadcastOperands = broadcastDynamicDimensions(
rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
- targetShape);
+ targetShape, converter);
}
// Returns the constant initial value for a given reduction operation. The
@@ -1100,13 +1115,16 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
namespace {
template <typename SrcOp>
-class PointwiseConverter : public OpRewritePattern<SrcOp> {
+class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
- using OpRewritePattern<SrcOp>::OpRewritePattern;
+ using OpConversionPattern<SrcOp>::OpConversionPattern;
+ using typename OpConversionPattern<SrcOp>::OpAdaptor;
- LogicalResult matchAndRewrite(SrcOp op,
- PatternRewriter &rewriter) const final {
- return elementwiseMatchAndRewriteHelper(op, rewriter);
+ LogicalResult
+ matchAndRewrite(SrcOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return elementwiseMatchAndRewriteHelper(
+ op, operands.getOperands(), rewriter, *this->getTypeConverter());
}
};
@@ -1279,7 +1297,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
loc, nestedBuilder.getI32IntegerAttr(intMax));
value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
- nestedBuilder);
+ nestedBuilder, /*isUnsigned=*/false);
if (outIntType.getWidth() < 32) {
value = nestedBuilder.create<arith::TruncIOp>(
@@ -1643,7 +1661,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
val = b.create<arith::AddIOp>(val, offset);
- val = clampIntHelper(loc, val, zeroI32, max, b);
+ val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
return b.create<arith::IndexCastOp>(b.getIndexType(), val);
};
@@ -1664,8 +1682,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value max, ImplicitLocOpBuilder &b) {
val0 = in;
val1 = b.create<arith::AddIOp>(val0, oneVal);
- val0 = clampIntHelper(loc, val0, zeroI32, max, b);
- val1 = clampIntHelper(loc, val1, zeroI32, max, b);
+ val0 =
+ clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
+ val1 =
+ clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
};
@@ -2555,7 +2575,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgConversionPatterns(
- RewritePatternSet *patterns) {
+ TypeConverter &converter, RewritePatternSet *patterns) {
// We have multiple resize coverters to handle degenerate cases.
patterns->add<GenericResizeConverter>(patterns->getContext(),
@@ -2602,7 +2622,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
PointwiseConverter<tosa::CeilOp>,
PointwiseConverter<tosa::FloorOp>,
PointwiseConverter<tosa::ClampOp>,
- PointwiseConverter<tosa::SigmoidOp>,
+ PointwiseConverter<tosa::SigmoidOp>
+ >(converter, patterns->getContext());
+
+ patterns->add<
IdentityNConverter<tosa::IdentityOp>,
ReduceConverter<tosa::ReduceAllOp>,
ReduceConverter<tosa::ReduceAnyOp>,
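A minimal scalar model of the bound normalization in the clamp lowering above (an illustrative helper written for this summary, not code from the patch): the `i64` `min_int`/`max_int` attributes are pinned to the representable range of the n-bit element type before the `arith.constant` ops are materialized.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Pin the requested clamp bounds to what an n-bit (un)signed integer can hold.
static void normalizeClampBounds(int64_t &min, int64_t &max, unsigned bitWidth,
                                 bool isUnsigned) {
  int64_t lo = INT64_MIN, hi = INT64_MAX;
  if (isUnsigned) {
    lo = 0;
    if (bitWidth <= 63)
      hi = (int64_t)((UINT64_C(1) << bitWidth) - 1); // e.g. 4294967295 for ui32
  } else if (bitWidth <= 64) {
    lo = (bitWidth == 64) ? INT64_MIN : -(INT64_C(1) << (bitWidth - 1));
    hi = (bitWidth == 64) ? INT64_MAX : (INT64_C(1) << (bitWidth - 1)) - 1;
  }
  min = std::clamp(min, lo, hi);
  max = std::clamp(max, lo, hi);
}

int main() {
  // Mirrors one of the ui32 tests below: min_int = max_int = INT64_MAX both
  // normalize to 4294967295, which the signless i32 constants print as -1.
  int64_t min = INT64_MAX, max = INT64_MAX;
  normalizeClampBounds(min, max, /*bitWidth=*/32, /*isUnsigned=*/true);
  assert(min == 4294967295LL && max == 4294967295LL);
  return 0;
}
```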
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d8fb3abc0bef8..77c3d2e875791 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1015,7 +1015,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
accETy);
- auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
+ auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
+ /*isUnsigned=*/false);
poolVal = clamp;
// Convert type.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 8904e3253922c..44036d7c31a91 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -63,8 +63,11 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+ TypeConverter converter;
+ tosa::populateTosaTypeConversion(converter);
+
FunctionOpInterface func = getOperation();
- mlir::tosa::populateTosaToLinalgConversionPatterns(&patterns);
+ mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
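For downstream users of the pattern-population API, the pass change above means a `TypeConverter` now has to be constructed and threaded through. A rough sketch of the new call shape; the include paths are assumptions and are not taken from this PR.

```cpp
// Sketch only: wiring inferred from the pass change above.
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h" // assumed home of populateTosaTypeConversion
#include "mlir/Transforms/DialectConversion.h"

static void buildTosaToLinalgPatterns(mlir::TypeConverter &converter,
                                      mlir::RewritePatternSet &patterns) {
  // Maps unsigned TOSA tensor element types to signless builtin integers.
  mlir::tosa::populateTosaTypeConversion(converter);
  // Pointwise patterns query the converter, so an unsigned tosa.clamp lowers
  // to arith.maxui/arith.minui instead of the signed counterparts.
  mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
}
```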
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 4fc97115064f3..f276924a8a9f6 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -38,7 +38,11 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
}
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
- OpBuilder &rewriter) {
+ OpBuilder &rewriter, bool isUnsigned) {
+ if (isUnsigned) {
+ auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
+ return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
+ }
auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
}
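A scalar model of the helper above (a sketch for intuition, not library code): the unsigned path corresponds to arith.maxui/arith.minui and the signed path to arith.maxsi/arith.minsi, in the same max-then-min order. The distinction matters because, for example, the 32-bit pattern 0xFFFFFFFF orders as 4294967295 under unsigned comparison but as -1 under signed comparison.

```cpp
#include <algorithm>
#include <cstdint>

// minui(max, maxui(min, arg))
uint32_t clampUnsignedScalar(uint32_t arg, uint32_t min, uint32_t max) {
  return std::min(max, std::max(min, arg));
}

// minsi(max, maxsi(min, arg))
int32_t clampSignedScalar(int32_t arg, int32_t min, int32_t max) {
  return std::min(max, std::max(min, arg));
}
```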
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 45b39f79a2a72..5187d79fd4c0b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -606,7 +606,7 @@ func.func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
// -----
// CHECK-LABEL: @test_simple_i32
-func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
+func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %unsigned64: tensor<1xui64>) -> () {
// CHECK: linalg.generic
// CHECK: arith.addi
%0 = tosa.add %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
@@ -700,6 +700,34 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK-DAG: arith.minsi
%19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
+ // CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
+ // CHECK-DAG: arith.maxui %[[LB]],
+ // CHECK-DAG: arith.minui %[[UB]],
+ %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
+ // CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
+ // CHECK-DAG: arith.maxui %[[LB]],
+ // CHECK-DAG: arith.minui %[[UB]],
+ %u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
+ // CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
+ // CHECK-DAG: arith.maxui %[[LB]],
+ // CHECK-DAG: arith.minui %[[UB]],
+ %u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
+ // CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
+ // CHECK-DAG: arith.maxui %[[LB]],
+ // CHECK-DAG: arith.minui %[[UB]],
+ %u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
+
// CHECK: linalg.generic
// CHECK: arith.trunci
%20 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xi16>
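Why the ui32 CHECK lines above expect `arith.constant -1 : i32`: after type conversion the clamp operates on signless i32, and the normalized unsigned bound 4294967295 shares its bit pattern with the signed value -1, which is how the constant prints. A one-line check of that equivalence:

```cpp
#include <cstdint>
// 0xFFFFFFFF reinterpreted as a signed 32-bit value is -1 (two's complement).
static_assert(static_cast<int32_t>(UINT32_C(4294967295)) == -1,
              "ui32 max and i32 -1 share a bit pattern");
```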
sjarus approved these changes on Jun 27, 2024
This implements the lowering of tosa.clamp with unsigned operand to linalg. We interpret the `min/max : i64` attributes on `clamp` to be signed. This means that when the operand has type `ui64`, one cannot represent limits across the whole range.
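A small illustration of that limitation, using nothing beyond the standard integer limits: the attributes are signed i64, so the largest upper bound one can request is 2^63 - 1, while a ui64 operand ranges up to 2^64 - 1.

```cpp
#include <cstdint>
// Largest max_int expressible in the signed i64 attribute.
static_assert(INT64_MAX == INT64_C(9223372036854775807), "2^63 - 1");
// The full ui64 range exceeds it, so it cannot be requested via max_int.
static_assert(UINT64_MAX == UINT64_C(18446744073709551615), "2^64 - 1");
```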