Skip to content

TosaToLinalg: Support unsigned tosa.clamp #91749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 27, 2024

Conversation

mgehre-amd
Copy link
Contributor

@mgehre-amd mgehre-amd commented May 10, 2024

This implements the lowering of tosa.clamp with unsigned operand to linalg.

We interpret the min/max : i64 attributes on clamp to be signed.

This means that when the operand has type ui64, one cannot represent limits across the whole range.

@mgehre-amd mgehre-amd changed the title TosaToTensor: Support reshape on tensors of unsigned integer TosaToLinalg: Support unsigned tosa.clamp May 10, 2024
Plump the TypeConverter into PointwiseConverter,
and emit unsigned comparisons when the input type is unsigned.
@mgehre-amd mgehre-amd force-pushed the mgehre.tosa_unsigned_clamp branch from 41f64ca to 4b801e0 Compare June 20, 2024 22:02
@mgehre-amd mgehre-amd marked this pull request as ready for review June 20, 2024 22:03
@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2024

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

Changes

On top of #91734


Full diff: https://github.com/llvm/llvm-project/pull/91749.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+2-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+1-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+62-39)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+2-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (+4-1)
  • (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+5-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+29-1)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 67965e34d8a3d..c84e4f17c38d8 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -47,7 +47,8 @@ void addTosaToLinalgPasses(
 void registerTosaToLinalgPipelines();
 
 /// Populates conversion passes from TOSA dialect to Linalg dialect.
-void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgConversionPatterns(TypeConverter &converter,
+                                            RewritePatternSet *patterns);
 
 /// Populates conversion passes from TOSA dialect to Linalg named operations.
 void populateTosaToLinalgNamedConversionPatterns(
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ca59b221d03eb..ceab7d9c628a5 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -37,7 +37,7 @@ Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
 // Takes the parameters for a clamp and turns it into a series of ops for
 // integer inputs.
 Value clampIntHelper(Location loc, Value arg, Value min, Value max,
-                     OpBuilder &rewriter);
+                     OpBuilder &rewriter, bool isUnsigned);
 
 // Determines whether the integer value falls witin the range of integer type.
 bool validIntegerRange(IntegerType ty, int64_t value);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8ad8e41414656..1442f2ad72255 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -46,10 +46,9 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
 
-static Value
-createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
-                                            ArrayRef<Type> resultTypes,
-                                            PatternRewriter &rewriter) {
+static Value createLinalgBodyCalculationForElementwiseOp(
+    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
+    ConversionPatternRewriter &rewriter) {
   Location loc = op->getLoc();
   auto elementTy =
       cast<ShapedType>(op->getOperand(0).getType()).getElementType();
@@ -186,7 +185,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     Value max = rewriter.create<arith::ConstantIntOp>(
         loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
         intermediateType);
-    auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
+    auto clamp =
+        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
 
     // Truncate to the final value.
     return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
@@ -389,25 +389,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     int64_t max =
         cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
 
+    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
+    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
     if (intTy.isUnsignedInteger()) {
-      min = std::max(min, (int64_t)0);
-      max = std::min(
-          max,
-          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
-    } else {
-      min =
-          std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
-                            .getSExtValue());
-      max =
-          std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
-                            .getSExtValue());
+      minRepresentable = 0;
+      if (intTy.getIntOrFloatBitWidth() <= 63) {
+        maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
+                          .getZExtValue();
+      }
+    } else if(intTy.getIntOrFloatBitWidth() <= 64) {
+      // Ensure that min & max fit into signed n-bit constants.
+      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+                            .getSExtValue();
+      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+                            .getSExtValue();
     }
+    // Ensure that the bounds are representable as n-bit signed/unsigned integers.
+    min = std::max(min, minRepresentable);
+    max = std::max(max, minRepresentable);
+    min = std::min(min, maxRepresentable);
+    max = std::min(max, maxRepresentable);
 
     auto minVal = rewriter.create<arith::ConstantIntOp>(
         loc, min, intTy.getIntOrFloatBitWidth());
     auto maxVal = rewriter.create<arith::ConstantIntOp>(
         loc, max, intTy.getIntOrFloatBitWidth());
-    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
+    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
+                          intTy.isUnsignedInteger());
   }
 
   // tosa::SigmoidOp
@@ -615,10 +623,9 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
 }
 
 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
-                                           Location loc, Operation *operation) {
-  auto rank =
-      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
+                                           Location loc, ValueRange operands,
+                                           int64_t rank) {
+  return llvm::map_to_vector(operands, [&](Value operand) {
     return expandRank(rewriter, loc, operand, rank);
   });
 }
@@ -843,11 +850,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
 }
 
 static LogicalResult
-emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
+emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                            Operation *operation, ValueRange operands,
-                           ArrayRef<OpFoldResult> targetShape) {
+                           ArrayRef<OpFoldResult> targetShape,
+                           const TypeConverter &converter) {
   // Generate output tensor
-  auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
+  auto resultType = cast_or_null<RankedTensorType>(
+      converter.convertType(operation->getResultTypes().front()));
+  if (!resultType) {
+    return rewriter.notifyMatchFailure(operation, "failed to convert type");
+  }
   Value outputTensor = rewriter.create<tensor::EmptyOp>(
       loc, targetShape, resultType.getElementType());
 
@@ -894,8 +906,9 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
 }
 
 static LogicalResult
-elementwiseMatchAndRewriteHelper(Operation *operation,
-                                 PatternRewriter &rewriter) {
+elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
+                                 ConversionPatternRewriter &rewriter,
+                                 const TypeConverter &converter) {
 
   // Collect op properties
   assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
@@ -908,13 +921,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   // Lower operation
   IndexPool indexPool;
   auto loc = operation->getLoc();
-  auto expandedOperands = expandInputRanks(rewriter, loc, operation);
+  auto rank =
+      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
+  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
   auto [targetShape, masterOperands] =
       computeTargetShape(rewriter, loc, indexPool, expandedOperands);
   auto broadcastOperands = broadcastDynamicDimensions(
       rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
   return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
-                                    targetShape);
+                                    targetShape, converter);
 }
 
 // Returns the constant initial value for a given reduction operation. The
@@ -1100,13 +1115,16 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
 namespace {
 
 template <typename SrcOp>
-class PointwiseConverter : public OpRewritePattern<SrcOp> {
+class PointwiseConverter : public OpConversionPattern<SrcOp> {
 public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+  using typename OpConversionPattern<SrcOp>::OpAdaptor;
 
-  LogicalResult matchAndRewrite(SrcOp op,
-                                PatternRewriter &rewriter) const final {
-    return elementwiseMatchAndRewriteHelper(op, rewriter);
+  LogicalResult
+  matchAndRewrite(SrcOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    return elementwiseMatchAndRewriteHelper(
+        op, operands.getOperands(), rewriter, *this->getTypeConverter());
   }
 };
 
@@ -1279,7 +1297,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               loc, nestedBuilder.getI32IntegerAttr(intMax));
 
           value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
-                                 nestedBuilder);
+                                 nestedBuilder, /*isUnsigned=*/false);
 
           if (outIntType.getWidth() < 32) {
             value = nestedBuilder.create<arith::TruncIOp>(
@@ -1643,7 +1661,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
 
           auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
           val = b.create<arith::AddIOp>(val, offset);
-          val = clampIntHelper(loc, val, zeroI32, max, b);
+          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
           return b.create<arith::IndexCastOp>(b.getIndexType(), val);
         };
 
@@ -1664,8 +1682,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
                                   Value max, ImplicitLocOpBuilder &b) {
           val0 = in;
           val1 = b.create<arith::AddIOp>(val0, oneVal);
-          val0 = clampIntHelper(loc, val0, zeroI32, max, b);
-          val1 = clampIntHelper(loc, val1, zeroI32, max, b);
+          val0 =
+              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
+          val1 =
+              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
           val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
           val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
         };
@@ -2555,7 +2575,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgConversionPatterns(
-    RewritePatternSet *patterns) {
+    TypeConverter &converter, RewritePatternSet *patterns) {
 
   // We have multiple resize coverters to handle degenerate cases.
   patterns->add<GenericResizeConverter>(patterns->getContext(),
@@ -2602,7 +2622,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       PointwiseConverter<tosa::CeilOp>,
       PointwiseConverter<tosa::FloorOp>,
       PointwiseConverter<tosa::ClampOp>,
-      PointwiseConverter<tosa::SigmoidOp>,
+      PointwiseConverter<tosa::SigmoidOp>
+        >(converter, patterns->getContext());
+
+  patterns->add<
       IdentityNConverter<tosa::IdentityOp>,
       ReduceConverter<tosa::ReduceAllOp>,
       ReduceConverter<tosa::ReduceAnyOp>,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d8fb3abc0bef8..77c3d2e875791 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1015,7 +1015,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             auto max = rewriter.create<arith::ConstantIntOp>(
                 loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                 accETy);
-            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
+            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
+                                        /*isUnsigned=*/false);
 
             poolVal = clamp;
             // Convert type.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 8904e3253922c..44036d7c31a91 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -63,8 +63,11 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
+    TypeConverter converter;
+    tosa::populateTosaTypeConversion(converter);
+
     FunctionOpInterface func = getOperation();
-    mlir::tosa::populateTosaToLinalgConversionPatterns(&patterns);
+    mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();
   }
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 4fc97115064f3..f276924a8a9f6 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -38,7 +38,11 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
 }
 
 Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
-                                 OpBuilder &rewriter) {
+                                 OpBuilder &rewriter, bool isUnsigned) {
+  if (isUnsigned) {
+    auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
+    return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
+  }
   auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
   return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 45b39f79a2a72..5187d79fd4c0b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -606,7 +606,7 @@ func.func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
 // -----
 
 // CHECK-LABEL: @test_simple_i32
-func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
+func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %unsigned64: tensor<1xui64>) -> () {
   // CHECK: linalg.generic
   // CHECK: arith.addi
   %0 = tosa.add %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
@@ -700,6 +700,34 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK-DAG: arith.minsi
   %19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
+  // CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
+  // CHECK-DAG: arith.maxui %[[LB]],
+  // CHECK-DAG: arith.minui %[[UB]],
+  %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
+  // CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
+  // CHECK-DAG: arith.maxui %[[LB]],
+  // CHECK-DAG: arith.minui %[[UB]],
+  %u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
+  // CHECK-DAG: arith.maxui %[[LB]],
+  // CHECK-DAG: arith.minui %[[UB]],
+  %u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
+  // CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
+  // CHECK-DAG: arith.maxui %[[LB]],
+  // CHECK-DAG: arith.minui %[[UB]],
+  %u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
+
   // CHECK: linalg.generic
   // CHECK: arith.trunci
   %20 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xi16>

@mgehre-amd mgehre-amd requested review from leandron and sjarus June 20, 2024 22:05
@mgehre-amd mgehre-amd merged commit 8d23719 into llvm:main Jun 27, 2024
12 checks passed
@mgehre-amd mgehre-amd deleted the mgehre.tosa_unsigned_clamp branch June 27, 2024 16:05
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
This implements the lowering of tosa.clamp with unsigned operand to
linalg.

We interpret the `min/max : i64`  attributes on `clamp` to be signed.

This means that when the operand has type `ui64`, one cannot represent
limits across the whole range.
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
This implements the lowering of tosa.clamp with unsigned operand to
linalg.

We interpret the `min/max : i64`  attributes on `clamp` to be signed.

This means that when the operand has type `ui64`, one cannot represent
limits across the whole range.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants