Skip to content

Commit 1c5fd15

Browse files
committed
[mlir][Tosa] Allow non-fp32 tosa.cast to integers
Fix the lowering of tosa.cast to create attributes of the input source type when casting from floats to integers. This is motivated by the need to cast fp16 to i9, which we have encountered in certain quantized models. Reviewed By: eric-k256, jpienaar Differential Revision: https://reviews.llvm.org/D158738
1 parent bcc8811 commit 1c5fd15

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,12 +481,14 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
481481

482482
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
483483
auto intMin = rewriter.create<arith::ConstantOp>(
484-
loc, rewriter.getF32FloatAttr(
484+
loc, rewriter.getFloatAttr(
485+
getElementTypeOrSelf(srcTy),
485486
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
486487
.getSExtValue()));
487488

488489
auto intMax = rewriter.create<arith::ConstantOp>(
489-
loc, rewriter.getF32FloatAttr(
490+
loc, rewriter.getFloatAttr(
491+
getElementTypeOrSelf(srcTy),
490492
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
491493
.getSExtValue()));
492494

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,14 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
551551
// CHECK: arith.extf
552552
%0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
553553

554+
// CHECK: linalg.generic
555+
// CHECK: arith.constant -1.280000e+02
556+
// CHECK: arith.constant 1.270000e+02
557+
// CHECK: math.roundeven
558+
// CHECK: arith.minf
559+
// CHECK: arith.maxf
560+
// CHECK: arith.fptosi
561+
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
554562
return
555563
}
556564

0 commit comments

Comments
 (0)