Skip to content

Commit fe355a4

Browse files
[MLIR][Math] Add support for f64 in the expansion of math.roundeven
Add support for f64 in the expansion of math.roundeven. Associated GitHub issue: iree-org/iree#13522. This is based on the offline discussion and essentially recommits https://reviews.llvm.org/D158234. Test plan: ninja check-mlir check-all
1 parent 8a407a5 commit fe355a4

File tree

8 files changed

+429
-72
lines changed

8 files changed

+429
-72
lines changed

mlir/include/mlir/ExecutionEngine/CRunnerUtils.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,6 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
469469
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
470470
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
471471
extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
472-
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF16(uint16_t bits); // bits!
473-
extern "C" MLIR_CRUNNERUTILS_EXPORT void printBF16(uint16_t bits); // bits!
474472

475473
//===----------------------------------------------------------------------===//
476474
// Small runtime support library for timing execution and printing GFLOPS

mlir/include/mlir/ExecutionEngine/Float16bits.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,8 @@ MLIR_FLOAT16_EXPORT std::ostream &operator<<(std::ostream &os, const f16 &f);
4848
// Outputs a bfloat value.
4949
MLIR_FLOAT16_EXPORT std::ostream &operator<<(std::ostream &os, const bf16 &d);
5050

51+
extern "C" MLIR_FLOAT16_EXPORT void printF16(uint16_t bits);
52+
extern "C" MLIR_FLOAT16_EXPORT void printBF16(uint16_t bits);
53+
5154
#undef MLIR_FLOAT16_EXPORT
5255
#endif // MLIR_EXECUTIONENGINE_FLOAT16BITS_H_

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class FloatType : public Type {
6767
unsigned getWidth();
6868

6969
/// Return the width of the mantissa of this type.
70+
/// The width includes the integer bit.
7071
unsigned getFPMantissaWidth();
7172

7273
/// Get or create a new FloatType with bitwidth scaled by `scale`.

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -305,31 +305,40 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
305305
Type operandETy = getElementTypeOrSelf(operandTy);
306306
Type resultETy = getElementTypeOrSelf(resultTy);
307307

308-
if (!operandETy.isF32() || !resultETy.isF32()) {
309-
return rewriter.notifyMatchFailure(op, "not a roundeven of f32.");
308+
if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
309+
return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
310310
}
311311

312-
Type i32Ty = b.getI32Type();
313-
Type f32Ty = b.getF32Type();
314-
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
315-
i32Ty = shapedTy.clone(i32Ty);
316-
f32Ty = shapedTy.clone(f32Ty);
312+
Type fTy = operandTy;
313+
Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
314+
if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
315+
iTy = shapedTy.clone(iTy);
317316
}
318317

319-
Value c1Float = createFloatConst(loc, f32Ty, 1.0, b);
320-
Value c0 = createIntConst(loc, i32Ty, 0, b);
321-
Value c1 = createIntConst(loc, i32Ty, 1, b);
322-
Value cNeg1 = createIntConst(loc, i32Ty, -1, b);
323-
Value c23 = createIntConst(loc, i32Ty, 23, b);
324-
Value c31 = createIntConst(loc, i32Ty, 31, b);
325-
Value c127 = createIntConst(loc, i32Ty, 127, b);
326-
Value c2To22 = createIntConst(loc, i32Ty, 1 << 22, b);
327-
Value c23Mask = createIntConst(loc, i32Ty, (1 << 23) - 1, b);
328-
Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
329-
330-
Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
318+
unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
319+
// The width returned by getFPMantissaWidth includes the integer bit.
320+
unsigned mantissaWidth =
321+
llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
322+
unsigned exponentWidth = bitWidth - mantissaWidth - 1;
323+
324+
// The names of the variables correspond to f32.
325+
// f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
326+
// f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
327+
// f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
328+
Value c1Float = createFloatConst(loc, fTy, 1.0, b);
329+
Value c0 = createIntConst(loc, iTy, 0, b);
330+
Value c1 = createIntConst(loc, iTy, 1, b);
331+
Value cNeg1 = createIntConst(loc, iTy, -1, b);
332+
Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
333+
Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
334+
Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
335+
Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
336+
Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
337+
Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
338+
339+
Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
331340
Value round = b.create<math::RoundOp>(operand);
332-
Value roundBitcast = b.create<arith::BitcastOp>(i32Ty, round);
341+
Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
333342

334343
// Get biased exponents for operand and round(operand)
335344
Value operandExp = b.create<arith::AndIOp>(
@@ -340,7 +349,7 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
340349
Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
341350

342351
auto safeShiftRight = [&](Value x, Value shift) -> Value {
343-
// Clamp shift to valid range [0, 31] to avoid undefined behavior
352+
// Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
344353
Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
345354
clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
346355
return b.create<arith::ShRUIOp>(x, clampedShift);

mlir/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ if(LLVM_ENABLE_PIC AND TARGET ${LLVM_NATIVE_ARCH})
119119
mlir-capi-execution-engine-test
120120
mlir_c_runner_utils
121121
mlir_runner_utils
122+
mlir_float16_utils
122123
)
123124
endif()
124125

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 172 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,91 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
232232

233233
// -----
234234

235-
// CHECK-LABEL: func.func @roundeven
236-
func.func @roundeven(%arg: f32) -> f32 {
235+
// CHECK-LABEL: func.func @roundeven64
236+
func.func @roundeven64(%arg: f64) -> f64 {
237+
%res = math.roundeven %arg : f64
238+
return %res : f64
239+
}
240+
241+
// CHECK-SAME: %[[VAL_0:.*]]: f64) -> f64 {
242+
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : i64
243+
// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i64
244+
// CHECK-DAG: %[[C_NEG_1:.*]] = arith.constant -1 : i64
245+
// CHECK-DAG: %[[C_1_FLOAT:.*]] = arith.constant 1.000000e+00 : f64
246+
// CHECK-DAG: %[[C_52:.*]] = arith.constant 52 : i64
247+
// CHECK-DAG: %[[C_63:.*]] = arith.constant 63 : i64
248+
// CHECK-DAG: %[[C_1023:.*]] = arith.constant 1023 : i64
249+
// CHECK-DAG: %[[C_2251799813685248:.*]] = arith.constant 2251799813685248 : i64
250+
// CHECK-DAG: %[[C_4503599627370495:.*]] = arith.constant 4503599627370495 : i64
251+
// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 2047 : i64
252+
// CHECK: %[[OPERAND_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f64 to i64
253+
// CHECK: %[[ROUND:.*]] = math.round %[[VAL_0]] : f64
254+
// CHECK: %[[ROUND_BITCAST:.*]] = arith.bitcast %[[ROUND]] : f64 to i64
255+
256+
// Get biased exponents of `round` and `operand`
257+
// CHECK: %[[SHIFTED_OPERAND_BITCAST:.*]] = arith.shrui %[[OPERAND_BITCAST]], %[[C_52]] : i64
258+
// CHECK: %[[OPERAND_EXP:.*]] = arith.andi %[[SHIFTED_OPERAND_BITCAST]], %[[EXP_MASK]] : i64
259+
// CHECK: %[[OPERAND_BIASED_EXP:.*]] = arith.subi %[[OPERAND_EXP]], %[[C_1023]] : i64
260+
// CHECK: %[[SHIFTED_ROUND_BITCAST:.*]] = arith.shrui %[[ROUND_BITCAST]], %[[C_52]] : i64
261+
// CHECK: %[[ROUND_EXP:.*]] = arith.andi %[[SHIFTED_ROUND_BITCAST]], %[[EXP_MASK]] : i64
262+
// CHECK: %[[ROUND_BIASED_EXP:.*]] = arith.subi %[[ROUND_EXP]], %[[C_1023]] : i64
263+
264+
// Determine if `ROUND_BITCAST` is an even whole number or a special value
265+
// +-inf, +-nan.
266+
// Mask mantissa of `ROUND_BITCAST` with a mask shifted to the right by
267+
// `ROUND_BIASED_EXP - 1`
268+
// CHECK-DAG: %[[ROUND_BIASED_EXP_MINUS_1:.*]] = arith.subi %[[ROUND_BIASED_EXP]], %[[C_1]] : i64
269+
// CHECK-DAG: %[[CLAMPED_SHIFT_0:.*]] = arith.maxsi %[[ROUND_BIASED_EXP_MINUS_1]], %[[C_0]] : i64
270+
// CHECK-DAG: %[[CLAMPED_SHIFT_1:.*]] = arith.minsi %[[CLAMPED_SHIFT_0]], %[[C_63]] : i64
271+
// CHECK-DAG: %[[SHIFTED_MANTISSA_MASK_0:.*]] = arith.shrui %[[C_4503599627370495]], %[[CLAMPED_SHIFT_1]] : i64
272+
// CHECK-DAG: %[[ROUND_MASKED_MANTISSA:.*]] = arith.andi %[[ROUND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_0]] : i64
273+
274+
// `ROUND_BITCAST` is not even whole number or special value if masked
275+
// mantissa is != 0 or `ROUND_BIASED_EXP == 0`
276+
// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0:.*]] = arith.cmpi ne, %[[ROUND_MASKED_MANTISSA]], %[[C_0]] : i64
277+
// CHECK-DAG: %[[ROUND_BIASED_EXP_EQ_0:.*]] = arith.cmpi eq, %[[ROUND_BIASED_EXP]], %[[C_0]] : i64
278+
// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1:.*]] = arith.ori %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0]], %[[ROUND_BIASED_EXP_EQ_0]] : i1
279+
280+
// Determine if operand is halfway between two integer values
281+
// CHECK: %[[OPERAND_BIASED_EXP_EQ_NEG_1:.*]] = arith.cmpi eq, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i64
282+
// CHECK: %[[CLAMPED_SHIFT_2:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i64
283+
// CHECK: %[[CLAMPED_SHIFT_3:.*]] = arith.minsi %[[CLAMPED_SHIFT_2]], %[[C_63]] : i64
284+
// CHECK: %[[SHIFTED_2_TO_9:.*]] = arith.shrui %[[C_2251799813685248]], %[[CLAMPED_SHIFT_3]] : i64
285+
286+
// CHECK: %[[EXPECTED_OPERAND_MASKED_MANTISSA:.*]] = arith.select %[[OPERAND_BIASED_EXP_EQ_NEG_1]], %[[C_0]], %[[SHIFTED_2_TO_9]] : i64
287+
288+
// Mask mantissa of `OPERAND_BITCAST` with a mask shifted to the right by
289+
// `OPERAND_BIASED_EXP`
290+
// CHECK: %[[CLAMPED_SHIFT_4:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i64
291+
// CHECK: %[[CLAMPED_SHIFT_5:.*]] = arith.minsi %[[CLAMPED_SHIFT_4]], %[[C_63]] : i64
292+
// CHECK: %[[SHIFTED_MANTISSA_MASK_1:.*]] = arith.shrui %[[C_4503599627370495]], %[[CLAMPED_SHIFT_5]] : i64
293+
// CHECK: %[[OPERAND_MASKED_MANTISSA:.*]] = arith.andi %[[OPERAND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_1]] : i64
294+
295+
// The operand is halfway between two integers if the masked mantissa is equal
296+
// to the expected mantissa and the biased exponent is in the range
297+
// [-1, 52).
298+
// CHECK-DAG: %[[OPERAND_BIASED_EXP_GE_NEG_1:.*]] = arith.cmpi sge, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i64
299+
// CHECK-DAG: %[[OPERAND_BIASED_EXP_LT_10:.*]] = arith.cmpi slt, %[[OPERAND_BIASED_EXP]], %[[C_52]] : i64
300+
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_0:.*]] = arith.cmpi eq, %[[OPERAND_MASKED_MANTISSA]], %[[EXPECTED_OPERAND_MASKED_MANTISSA]] : i64
301+
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_1:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_0]], %[[OPERAND_BIASED_EXP_LT_10]] : i1
302+
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_2:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_1]], %[[OPERAND_BIASED_EXP_GE_NEG_1]] : i1
303+
304+
// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
305+
// case where `round` rounded in the opposite direction of `roundeven`.
306+
// CHECK: %[[SIGN:.*]] = math.copysign %[[C_1_FLOAT]], %[[VAL_0]] : f64
307+
// CHECK: %[[ROUND_SHIFTED:.*]] = arith.subf %[[ROUND]], %[[SIGN]] : f64
308+
// CHECK: %[[NEEDS_SHIFT:.*]] = arith.andi %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1]], %[[OPERAND_IS_HALFWAY_2]] : i1
309+
// CHECK: %[[RESULT:.*]] = arith.select %[[NEEDS_SHIFT]], %[[ROUND_SHIFTED]], %[[ROUND]] : f64
310+
311+
// The `x - sign` adjustment does not preserve the sign when we are adjusting the value -1 to -0.
312+
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f64
313+
314+
// CHECK: return %[[COPYSIGN]] : f64
315+
316+
// -----
317+
318+
// CHECK-LABEL: func.func @roundeven32
319+
func.func @roundeven32(%arg: f32) -> f32 {
237320
%res = math.roundeven %arg : f32
238321
return %res : f32
239322
}
@@ -331,3 +414,90 @@ func.func @roundeven(%arg: f32) -> f32 {
331414
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f32
332415

333416
// CHECK: return %[[COPYSIGN]] : f32
417+
418+
// -----
419+
420+
// CHECK-LABEL: func.func @roundeven16
421+
func.func @roundeven16(%arg: f16) -> f16 {
422+
%res = math.roundeven %arg : f16
423+
return %res : f16
424+
}
425+
426+
// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 {
427+
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : i16
428+
// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i16
429+
// CHECK-DAG: %[[C_NEG_1:.*]] = arith.constant -1 : i16
430+
// CHECK-DAG: %[[C_1_FLOAT:.*]] = arith.constant 1.000000e+00 : f16
431+
// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i16
432+
// CHECK-DAG: %[[C_15:.*]] = arith.constant 15 : i16
433+
// CHECK-DAG: %[[C_512:.*]] = arith.constant 512 : i16
434+
// CHECK-DAG: %[[C_1023:.*]] = arith.constant 1023 : i16
435+
// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 31 : i16
436+
437+
// CHECK: %[[OPERAND_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f16 to i16
438+
// CHECK: %[[ROUND:.*]] = math.round %[[VAL_0]] : f16
439+
// CHECK: %[[ROUND_BITCAST:.*]] = arith.bitcast %[[ROUND]] : f16 to i16
440+
441+
// Get biased exponents of `round` and `operand`
442+
// CHECK: %[[SHIFTED_OPERAND_BITCAST:.*]] = arith.shrui %[[OPERAND_BITCAST]], %[[C_10]] : i16
443+
// CHECK: %[[OPERAND_EXP:.*]] = arith.andi %[[SHIFTED_OPERAND_BITCAST]], %[[EXP_MASK]] : i16
444+
// CHECK: %[[OPERAND_BIASED_EXP:.*]] = arith.subi %[[OPERAND_EXP]], %[[C_15]] : i16
445+
// CHECK: %[[SHIFTED_ROUND_BITCAST:.*]] = arith.shrui %[[ROUND_BITCAST]], %[[C_10]] : i16
446+
// CHECK: %[[ROUND_EXP:.*]] = arith.andi %[[SHIFTED_ROUND_BITCAST]], %[[EXP_MASK]] : i16
447+
// CHECK: %[[ROUND_BIASED_EXP:.*]] = arith.subi %[[ROUND_EXP]], %[[C_15]] : i16
448+
449+
// Determine if `ROUND_BITCAST` is an even whole number or a special value
450+
// +-inf, +-nan.
451+
// Mask mantissa of `ROUND_BITCAST` with a mask shifted to the right by
452+
// `ROUND_BIASED_EXP - 1`
453+
// CHECK-DAG: %[[ROUND_BIASED_EXP_MINUS_1:.*]] = arith.subi %[[ROUND_BIASED_EXP]], %[[C_1]] : i16
454+
// CHECK-DAG: %[[CLAMPED_SHIFT_0:.*]] = arith.maxsi %[[ROUND_BIASED_EXP_MINUS_1]], %[[C_0]] : i16
455+
// CHECK-DAG: %[[CLAMPED_SHIFT_1:.*]] = arith.minsi %[[CLAMPED_SHIFT_0]], %[[C_15]] : i16
456+
// CHECK-DAG: %[[SHIFTED_MANTISSA_MASK_0:.*]] = arith.shrui %[[C_1023]], %[[CLAMPED_SHIFT_1]] : i16
457+
// CHECK-DAG: %[[ROUND_MASKED_MANTISSA:.*]] = arith.andi %[[ROUND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_0]] : i16
458+
459+
// `ROUND_BITCAST` is not even whole number or special value if masked
460+
// mantissa is != 0 or `ROUND_BIASED_EXP == 0`
461+
// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0:.*]] = arith.cmpi ne, %[[ROUND_MASKED_MANTISSA]], %[[C_0]] : i16
462+
// CHECK-DAG: %[[ROUND_BIASED_EXP_EQ_0:.*]] = arith.cmpi eq, %[[ROUND_BIASED_EXP]], %[[C_0]] : i16
463+
// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1:.*]] = arith.ori %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0]], %[[ROUND_BIASED_EXP_EQ_0]] : i1
464+
465+
// Determine if operand is halfway between two integer values
466+
// CHECK: %[[OPERAND_BIASED_EXP_EQ_NEG_1:.*]] = arith.cmpi eq, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i16
467+
// CHECK: %[[CLAMPED_SHIFT_2:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i16
468+
// CHECK: %[[CLAMPED_SHIFT_3:.*]] = arith.minsi %[[CLAMPED_SHIFT_2]], %[[C_15]] : i16
469+
// CHECK: %[[SHIFTED_2_TO_9:.*]] = arith.shrui %[[C_512]], %[[CLAMPED_SHIFT_3]] : i16
470+
471+
// A value with `0 <= BIASED_EXP < 10` is halfway between two consecutive
472+
// integers if the bit at index `BIASED_EXP` starting from the left in the
473+
// mantissa is 1 and all the bits to the right are zero. For the case where
474+
// `BIASED_EXP == -1`, the expected mantissa is all zeros.
475+
// CHECK: %[[EXPECTED_OPERAND_MASKED_MANTISSA:.*]] = arith.select %[[OPERAND_BIASED_EXP_EQ_NEG_1]], %[[C_0]], %[[SHIFTED_2_TO_9]] : i16
476+
477+
// Mask mantissa of `OPERAND_BITCAST` with a mask shifted to the right by
478+
// `OPERAND_BIASED_EXP`
479+
// CHECK: %[[CLAMPED_SHIFT_4:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i16
480+
// CHECK: %[[CLAMPED_SHIFT_5:.*]] = arith.minsi %[[CLAMPED_SHIFT_4]], %[[C_15]] : i16
481+
// CHECK: %[[SHIFTED_MANTISSA_MASK_1:.*]] = arith.shrui %[[C_1023]], %[[CLAMPED_SHIFT_5]] : i16
482+
// CHECK: %[[OPERAND_MASKED_MANTISSA:.*]] = arith.andi %[[OPERAND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_1]] : i16
483+
484+
// The operand is halfway between two integers if the masked mantissa is equal
485+
// to the expected mantissa and the biased exponent is in the range
486+
// [-1, 10).
487+
// CHECK-DAG: %[[OPERAND_BIASED_EXP_GE_NEG_1:.*]] = arith.cmpi sge, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i16
488+
// CHECK-DAG: %[[OPERAND_BIASED_EXP_LT_10:.*]] = arith.cmpi slt, %[[OPERAND_BIASED_EXP]], %[[C_10]] : i16
489+
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_0:.*]] = arith.cmpi eq, %[[OPERAND_MASKED_MANTISSA]], %[[EXPECTED_OPERAND_MASKED_MANTISSA]] : i16
490+
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_1:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_0]], %[[OPERAND_BIASED_EXP_LT_10]] : i1
491+
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_2:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_1]], %[[OPERAND_BIASED_EXP_GE_NEG_1]] : i1
492+
493+
// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
494+
// case where `round` rounded in the opposite direction of `roundeven`.
495+
// CHECK: %[[SIGN:.*]] = math.copysign %[[C_1_FLOAT]], %[[VAL_0]] : f16
496+
// CHECK: %[[ROUND_SHIFTED:.*]] = arith.subf %[[ROUND]], %[[SIGN]] : f16
497+
// CHECK: %[[NEEDS_SHIFT:.*]] = arith.andi %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1]], %[[OPERAND_IS_HALFWAY_2]] : i1
498+
// CHECK: %[[RESULT:.*]] = arith.select %[[NEEDS_SHIFT]], %[[ROUND_SHIFTED]], %[[ROUND]] : f16
499+
500+
// The `x - sign` adjustment does not preserve the sign when we are adjusting the value -1 to -0.
501+
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
502+
503+
// CHECK: return %[[COPYSIGN]] : f16

mlir/test/lit.cfg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def add_runtime(name):
9898
add_runtime("mlir_runner_utils"),
9999
add_runtime("mlir_c_runner_utils"),
100100
add_runtime("mlir_async_runtime"),
101+
add_runtime("mlir_float16_utils"),
101102
"mlir-linalg-ods-yaml-gen",
102103
"mlir-reduce",
103104
"mlir-pdll",

0 commit comments

Comments
 (0)