Skip to content

Commit 40bf363

Browse files
[MLIR][Math] Add support for f16 in the expansion of math.roundeven
Add support for f16 in the expansion of math.roundeven. Associated GitHub issue: iree-org/iree#13522. This version addresses the Windows build issues reported on https://reviews.llvm.org/D157204. Test plan: ninja check-mlir check-all. Differential revision: https://reviews.llvm.org/D158234
1 parent 5a6c1ce commit 40bf363

File tree

8 files changed

+352
-71
lines changed

8 files changed

+352
-71
lines changed

mlir/include/mlir/ExecutionEngine/CRunnerUtils.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,6 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
469469
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
470470
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
471471
extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
472-
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF16(uint16_t bits); // bits!
473-
extern "C" MLIR_CRUNNERUTILS_EXPORT void printBF16(uint16_t bits); // bits!
474472

475473
//===----------------------------------------------------------------------===//
476474
// Small runtime support library for timing execution and printing GFLOPS

mlir/include/mlir/ExecutionEngine/Float16bits.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,8 @@ MLIR_FLOAT16_EXPORT std::ostream &operator<<(std::ostream &os, const f16 &f);
4848
// Outputs a bfloat value.
4949
MLIR_FLOAT16_EXPORT std::ostream &operator<<(std::ostream &os, const bf16 &d);
5050

51+
extern "C" MLIR_FLOAT16_EXPORT void printF16(uint16_t bits);
52+
extern "C" MLIR_FLOAT16_EXPORT void printBF16(uint16_t bits);
53+
5154
#undef MLIR_FLOAT16_EXPORT
5255
#endif // MLIR_EXECUTIONENGINE_FLOAT16BITS_H_

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class FloatType : public Type {
6767
unsigned getWidth();
6868

6969
/// Return the width of the mantissa of this type.
70+
/// The width includes the integer bit.
7071
unsigned getFPMantissaWidth();
7172

7273
/// Get or create a new FloatType with bitwidth scaled by `scale`.

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -305,31 +305,39 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
305305
Type operandETy = getElementTypeOrSelf(operandTy);
306306
Type resultETy = getElementTypeOrSelf(resultTy);
307307

308-
if (!operandETy.isF32() || !resultETy.isF32()) {
309-
return rewriter.notifyMatchFailure(op, "not a roundeven of f32.");
308+
if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
309+
return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
310310
}
311311

312-
Type i32Ty = b.getI32Type();
313-
Type f32Ty = b.getF32Type();
314-
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
315-
i32Ty = shapedTy.clone(i32Ty);
316-
f32Ty = shapedTy.clone(f32Ty);
312+
Type fTy = operandTy;
313+
Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
314+
if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
315+
iTy = shapedTy.clone(iTy);
317316
}
318317

319-
Value c1Float = createFloatConst(loc, f32Ty, 1.0, b);
320-
Value c0 = createIntConst(loc, i32Ty, 0, b);
321-
Value c1 = createIntConst(loc, i32Ty, 1, b);
322-
Value cNeg1 = createIntConst(loc, i32Ty, -1, b);
323-
Value c23 = createIntConst(loc, i32Ty, 23, b);
324-
Value c31 = createIntConst(loc, i32Ty, 31, b);
325-
Value c127 = createIntConst(loc, i32Ty, 127, b);
326-
Value c2To22 = createIntConst(loc, i32Ty, 1 << 22, b);
327-
Value c23Mask = createIntConst(loc, i32Ty, (1 << 23) - 1, b);
328-
Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
329-
330-
Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
318+
unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
319+
// The width returned by getFPMantissaWidth includes the integer bit.
320+
unsigned mantissaWidth =
321+
llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
322+
unsigned exponentWidth = bitWidth - mantissaWidth - 1;
323+
324+
// The names of the variables below (c23, c31, c127, ...) correspond to the f32 layout; their values are derived from the operand type.
325+
// f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
326+
// f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
327+
Value c1Float = createFloatConst(loc, fTy, 1.0, b);
328+
Value c0 = createIntConst(loc, iTy, 0, b);
329+
Value c1 = createIntConst(loc, iTy, 1, b);
330+
Value cNeg1 = createIntConst(loc, iTy, -1, b);
331+
Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
332+
Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
333+
Value c127 = createIntConst(loc, iTy, (1 << (exponentWidth - 1)) - 1, b);
334+
Value c2To22 = createIntConst(loc, iTy, 1 << (mantissaWidth - 1), b);
335+
Value c23Mask = createIntConst(loc, iTy, (1 << mantissaWidth) - 1, b);
336+
Value expMask = createIntConst(loc, iTy, (1 << exponentWidth) - 1, b);
337+
338+
Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
331339
Value round = b.create<math::RoundOp>(operand);
332-
Value roundBitcast = b.create<arith::BitcastOp>(i32Ty, round);
340+
Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
333341

334342
// Get biased exponents for operand and round(operand)
335343
Value operandExp = b.create<arith::AndIOp>(
@@ -340,7 +348,7 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
340348
Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
341349

342350
auto safeShiftRight = [&](Value x, Value shift) -> Value {
343-
// Clamp shift to valid range [0, 31] to avoid undefined behavior
351+
// Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
344352
Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
345353
clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
346354
return b.create<arith::ShRUIOp>(x, clampedShift);

mlir/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ if(LLVM_ENABLE_PIC AND TARGET ${LLVM_NATIVE_ARCH})
119119
mlir-capi-execution-engine-test
120120
mlir_c_runner_utils
121121
mlir_runner_utils
122+
mlir_float16_utils
122123
)
123124
endif()
124125

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
232232

233233
// -----
234234

235-
// CHECK-LABEL: func.func @roundeven
236-
func.func @roundeven(%arg: f32) -> f32 {
235+
// CHECK-LABEL: func.func @roundeven32
236+
func.func @roundeven32(%arg: f32) -> f32 {
237237
%res = math.roundeven %arg : f32
238238
return %res : f32
239239
}
@@ -331,3 +331,90 @@ func.func @roundeven(%arg: f32) -> f32 {
331331
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f32
332332

333333
// CHECK: return %[[COPYSIGN]] : f32
334+
335+
// -----
336+
337+
// CHECK-LABEL: func.func @roundeven16
338+
func.func @roundeven16(%arg: f16) -> f16 {
339+
%res = math.roundeven %arg : f16
340+
return %res : f16
341+
}
342+
343+
// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 {
344+
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : i16
345+
// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i16
346+
// CHECK-DAG: %[[C_NEG_1:.*]] = arith.constant -1 : i16
347+
// CHECK-DAG: %[[C_1_FLOAT:.*]] = arith.constant 1.000000e+00 : f16
348+
// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i16
349+
// CHECK-DAG: %[[C_15:.*]] = arith.constant 15 : i16
350+
// CHECK-DAG: %[[C_512:.*]] = arith.constant 512 : i16
351+
// CHECK-DAG: %[[C_1023:.*]] = arith.constant 1023 : i16
352+
// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 31 : i16
353+
354+
// CHECK: %[[OPERAND_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f16 to i16
355+
// CHECK: %[[ROUND:.*]] = math.round %[[VAL_0]] : f16
356+
// CHECK: %[[ROUND_BITCAST:.*]] = arith.bitcast %[[ROUND]] : f16 to i16
357+
358+
// Get biased exponents of `round` and `operand`
359+
// CHECK: %[[SHIFTED_OPERAND_BITCAST:.*]] = arith.shrui %[[OPERAND_BITCAST]], %[[C_10]] : i16
360+
// CHECK: %[[OPERAND_EXP:.*]] = arith.andi %[[SHIFTED_OPERAND_BITCAST]], %[[EXP_MASK]] : i16
361+
// CHECK: %[[OPERAND_BIASED_EXP:.*]] = arith.subi %[[OPERAND_EXP]], %[[C_15]] : i16
362+
// CHECK: %[[SHIFTED_ROUND_BITCAST:.*]] = arith.shrui %[[ROUND_BITCAST]], %[[C_10]] : i16
363+
// CHECK: %[[ROUND_EXP:.*]] = arith.andi %[[SHIFTED_ROUND_BITCAST]], %[[EXP_MASK]] : i16
364+
// CHECK: %[[ROUND_BIASED_EXP:.*]] = arith.subi %[[ROUND_EXP]], %[[C_15]] : i16
365+
366+
// Determine if `ROUND_BITCAST` is an even whole number or a special value
367+
// +-inf, +-nan.
368+
// Mask mantissa of `ROUND_BITCAST` with a mask shifted to the right by
369+
// `ROUND_BIASED_EXP - 1`
370+
// CHECK-DAG: %[[ROUND_BIASED_EXP_MINUS_1:.*]] = arith.subi %[[ROUND_BIASED_EXP]], %[[C_1]] : i16
371+
// CHECK-DAG: %[[CLAMPED_SHIFT_0:.*]] = arith.maxsi %[[ROUND_BIASED_EXP_MINUS_1]], %[[C_0]] : i16
372+
// CHECK-DAG: %[[CLAMPED_SHIFT_1:.*]] = arith.minsi %[[CLAMPED_SHIFT_0]], %[[C_15]] : i16
373+
// CHECK-DAG: %[[SHIFTED_MANTISSA_MASK_0:.*]] = arith.shrui %[[C_1023]], %[[CLAMPED_SHIFT_1]] : i16
374+
// CHECK-DAG: %[[ROUND_MASKED_MANTISSA:.*]] = arith.andi %[[ROUND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_0]] : i16
375+
376+
// `ROUND_BITCAST` is not even whole number or special value if masked
377+
// mantissa is != 0 or `ROUND_BIASED_EXP == 0`
378+
// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0:.*]] = arith.cmpi ne, %[[ROUND_MASKED_MANTISSA]], %[[C_0]] : i16
379+
// CHECK-DAG: %[[ROUND_BIASED_EXP_EQ_0:.*]] = arith.cmpi eq, %[[ROUND_BIASED_EXP]], %[[C_0]] : i16
380+
// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1:.*]] = arith.ori %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0]], %[[ROUND_BIASED_EXP_EQ_0]] : i1
381+
382+
// Determine if operand is halfway between two integer values
383+
// CHECK: %[[OPERAND_BIASED_EXP_EQ_NEG_1:.*]] = arith.cmpi eq, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i16
384+
// CHECK: %[[CLAMPED_SHIFT_2:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i16
385+
// CHECK: %[[CLAMPED_SHIFT_3:.*]] = arith.minsi %[[CLAMPED_SHIFT_2]], %[[C_15]] : i16
386+
// CHECK: %[[SHIFTED_2_TO_9:.*]] = arith.shrui %[[C_512]], %[[CLAMPED_SHIFT_3]] : i16
387+
388+
// A value with `0 <= BIASED_EXP < 10` is halfway between two consecutive
389+
// integers if the bit at index `BIASED_EXP` starting from the left in the
390+
// mantissa is 1 and all the bits to the right are zero. For the case where
391+
// `BIASED_EXP == -1`, the expected mantissa is all zeros.
392+
// CHECK: %[[EXPECTED_OPERAND_MASKED_MANTISSA:.*]] = arith.select %[[OPERAND_BIASED_EXP_EQ_NEG_1]], %[[C_0]], %[[SHIFTED_2_TO_9]] : i16
393+
394+
// Mask mantissa of `OPERAND_BITCAST` with a mask shifted to the right by
395+
// `OPERAND_BIASED_EXP`
396+
// CHECK: %[[CLAMPED_SHIFT_4:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i16
397+
// CHECK: %[[CLAMPED_SHIFT_5:.*]] = arith.minsi %[[CLAMPED_SHIFT_4]], %[[C_15]] : i16
398+
// CHECK: %[[SHIFTED_MANTISSA_MASK_1:.*]] = arith.shrui %[[C_1023]], %[[CLAMPED_SHIFT_5]] : i16
399+
// CHECK: %[[OPERAND_MASKED_MANTISSA:.*]] = arith.andi %[[OPERAND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_1]] : i16
400+
401+
// The operand is halfway between two integers if the masked mantissa is equal
402+
// to the expected mantissa and the biased exponent is in the range
403+
// [-1, 10).
404+
// CHECK-DAG: %[[OPERAND_BIASED_EXP_GE_NEG_1:.*]] = arith.cmpi sge, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i16
405+
// CHECK-DAG: %[[OPERAND_BIASED_EXP_LT_10:.*]] = arith.cmpi slt, %[[OPERAND_BIASED_EXP]], %[[C_10]] : i16
406+
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_0:.*]] = arith.cmpi eq, %[[OPERAND_MASKED_MANTISSA]], %[[EXPECTED_OPERAND_MASKED_MANTISSA]] : i16
407+
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_1:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_0]], %[[OPERAND_BIASED_EXP_LT_10]] : i1
408+
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_2:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_1]], %[[OPERAND_BIASED_EXP_GE_NEG_1]] : i1
409+
410+
// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
411+
// case where `round` rounded in the opposite direction of `roundeven`.
412+
// CHECK: %[[SIGN:.*]] = math.copysign %[[C_1_FLOAT]], %[[VAL_0]] : f16
413+
// CHECK: %[[ROUND_SHIFTED:.*]] = arith.subf %[[ROUND]], %[[SIGN]] : f16
414+
// CHECK: %[[NEEDS_SHIFT:.*]] = arith.andi %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1]], %[[OPERAND_IS_HALFWAY_2]] : i1
415+
// CHECK: %[[RESULT:.*]] = arith.select %[[NEEDS_SHIFT]], %[[ROUND_SHIFTED]], %[[ROUND]] : f16
416+
417+
// The `x - sign` adjustment does not preserve the sign when we are adjusting the value -1 to -0.
418+
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
419+
420+
// CHECK: return %[[COPYSIGN]] : f16

mlir/test/lit.cfg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def add_runtime(name):
9898
add_runtime("mlir_runner_utils"),
9999
add_runtime("mlir_c_runner_utils"),
100100
add_runtime("mlir_async_runtime"),
101+
add_runtime("mlir_float16_utils"),
101102
"mlir-linalg-ods-yaml-gen",
102103
"mlir-reduce",
103104
"mlir-pdll",

0 commit comments

Comments
 (0)