Skip to content

Commit a0c9b58

Browse files
committed
TosaToLinalg: Fixed tosa.cast for small unsigned integers
1 parent 2d89023 commit a0c9b58

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
469469
args.front(), zero);
470470
}
471471

472-
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
472+
if (dstTy.isSignlessInteger() &&
473+
arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
473474
auto intMin = rewriter.create<arith::ConstantOp>(
474475
loc, rewriter.getF32FloatAttr(
475476
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
@@ -484,16 +485,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
484485

485486
auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
486487

487-
if (dstTy.isUnsignedInteger()) {
488-
auto cast = rewriter.create<arith::FPToUIOp>(
489-
loc, rewriter.getIntegerType(dstTy.getIntOrFloatBitWidth()), clamped);
490-
return rewriter.create<UnrealizedConversionCastOp>(
491-
loc, dstTy, cast->getResult(0)).getResult(0);
492-
}
493-
494488
return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
495489
}
496490

491+
if (dstTy.isUnsignedInteger() &&
492+
arith::FPToUIOp::areCastCompatible(srcTy, dstTy)) {
493+
auto intMin = rewriter.create<arith::ConstantOp>(
494+
loc, rewriter.getF32FloatAttr(
495+
APInt::getMinValue(dstTy.getIntOrFloatBitWidth())
496+
.getZExtValue()));
497+
498+
auto intMax = rewriter.create<arith::ConstantOp>(
499+
loc, rewriter.getF32FloatAttr(
500+
APInt::getMaxValue(dstTy.getIntOrFloatBitWidth())
501+
.getZExtValue()));
502+
503+
auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
504+
505+
auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
506+
507+
auto cast = rewriter.create<arith::FPToUIOp>(
508+
loc, rewriter.getIntegerType(dstTy.getIntOrFloatBitWidth()), clamped);
509+
// arith is signless, so temporarily cast back to being unsigned.
510+
return rewriter
511+
.create<UnrealizedConversionCastOp>(loc, dstTy, cast->getResult(0))
512+
.getResult(0);
513+
}
514+
497515
// Casting to boolean, integers need to only be checked as not-equal to
498516
// zero.
499517
if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-ui3.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
func.func @test_cast(%arg0: tensor<1xf32>) -> tensor<1xui3> {
44
// CHECK: linalg.generic
5-
// CHECK: arith.constant -4.000000e+00
6-
// CHECK: arith.constant 3.000000e+00
5+
// CHECK: arith.constant 0.000000e+00
6+
// CHECK: arith.constant 7.000000e+00
77
// CHECK: math.roundeven
88
// CHECK: arith.minf
99
// CHECK: arith.maxf
10-
// CHECK: arith.fptoui
10+
// CHECK: arith.fptoui {{.*}} : f32 to i3
1111
// CHECK: builtin.unrealized_conversion_cast
1212
%1 = "tosa.cast"(%arg0) : (tensor<1xf32>) -> tensor<1xui3>
1313

0 commit comments

Comments
 (0)