Skip to content

Commit c2367d6

Browse files
authored
TosaToLinAlg: fix tosa.cast legalization of FP->Int for non FP32 types. (#45)
1 parent 20fa0e8 commit c2367d6

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,16 +471,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
471471
}
472472

473473
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
474-
auto intMin = rewriter.create<arith::ConstantOp>(
474+
Value intMin = rewriter.create<arith::ConstantOp>(
475475
loc, rewriter.getF32FloatAttr(
476476
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
477477
.getSExtValue()));
478478

479-
auto intMax = rewriter.create<arith::ConstantOp>(
479+
Value intMax = rewriter.create<arith::ConstantOp>(
480480
loc, rewriter.getF32FloatAttr(
481481
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
482482
.getSExtValue()));
483483

484+
// Since F32 constants are created, we may still need to convert them to
485+
// the correct type.
486+
auto convertType = [&](Type ty, Value arg) {
487+
auto argTy = arg.getType();
488+
bool bitExtend =
489+
argTy.getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth();
490+
if (ty != argTy) {
491+
if (!bitExtend)
492+
arg = rewriter.create<arith::TruncFOp>(loc, ty, arg);
493+
else
494+
arg = rewriter.create<arith::ExtFOp>(loc, ty, arg);
495+
}
496+
return arg;
497+
};
498+
intMin = convertType(srcTy, intMin);
499+
intMax = convertType(srcTy, intMax);
500+
484501
auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
485502

486503
auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,17 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
270270
// CHECK: arith.extf
271271
%0 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xf32>
272272

273+
// CHECK: linalg.generic
274+
// CHECK: %[[C_LOWEST:.+]] = arith.constant -2.14748365E+9
275+
// CHECK: %[[C_MAX:.+]] = arith.constant 2.14748365E+9
276+
// CHECK: arith.truncf %[[C_LOWEST]] : f32 to f16
277+
// CHECK: arith.truncf %[[C_MAX]] : f32 to f16
278+
// CHECK: math.roundeven
279+
// CHECK: arith.minf
280+
// CHECK: arith.maxf
281+
// CHECK: arith.fptosi
282+
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
283+
273284
return
274285
}
275286

0 commit comments

Comments
 (0)