Skip to content

Commit d22883e

Browse files
Revert "[MLIR][Math] Add support for f16 in the expansion of math.roundeven"
This reverts commit 40bf363. The build bot ppc64le-mlir-rhel-test got broken by these changes, see https://lab.llvm.org/buildbot#builders/88/builds/61048 .
1 parent 3e3880e commit d22883e

File tree

8 files changed

+71
-352
lines changed

8 files changed

+71
-352
lines changed

mlir/include/mlir/ExecutionEngine/CRunnerUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,8 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
469469
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
470470
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
471471
extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
472+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF16(uint16_t bits); // bits!
473+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printBF16(uint16_t bits); // bits!
472474

473475
//===----------------------------------------------------------------------===//
474476
// Small runtime support library for timing execution and printing GFLOPS

mlir/include/mlir/ExecutionEngine/Float16bits.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,5 @@ MLIR_FLOAT16_EXPORT std::ostream &operator<<(std::ostream &os, const f16 &f);
4848
// Outputs a bfloat value.
4949
MLIR_FLOAT16_EXPORT std::ostream &operator<<(std::ostream &os, const bf16 &d);
5050

51-
extern "C" MLIR_FLOAT16_EXPORT void printF16(uint16_t bits);
52-
extern "C" MLIR_FLOAT16_EXPORT void printBF16(uint16_t bits);
53-
5451
#undef MLIR_FLOAT16_EXPORT
5552
#endif // MLIR_EXECUTIONENGINE_FLOAT16BITS_H_

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ class FloatType : public Type {
6767
unsigned getWidth();
6868

6969
/// Return the width of the mantissa of this type.
70-
/// The width includes the integer bit.
7170
unsigned getFPMantissaWidth();
7271

7372
/// Get or create a new FloatType with bitwidth scaled by `scale`.

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -305,39 +305,31 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
305305
Type operandETy = getElementTypeOrSelf(operandTy);
306306
Type resultETy = getElementTypeOrSelf(resultTy);
307307

308-
if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
309-
return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
308+
if (!operandETy.isF32() || !resultETy.isF32()) {
309+
return rewriter.notifyMatchFailure(op, "not a roundeven of f32.");
310310
}
311311

312-
Type fTy = operandTy;
313-
Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
314-
if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
315-
iTy = shapedTy.clone(iTy);
312+
Type i32Ty = b.getI32Type();
313+
Type f32Ty = b.getF32Type();
314+
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
315+
i32Ty = shapedTy.clone(i32Ty);
316+
f32Ty = shapedTy.clone(f32Ty);
316317
}
317318

318-
unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
319-
// The width returned by getFPMantissaWidth includes the integer bit.
320-
unsigned mantissaWidth =
321-
llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
322-
unsigned exponentWidth = bitWidth - mantissaWidth - 1;
323-
324-
// The names of the variables correspond to f32.
325-
// f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
326-
// f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
327-
Value c1Float = createFloatConst(loc, fTy, 1.0, b);
328-
Value c0 = createIntConst(loc, iTy, 0, b);
329-
Value c1 = createIntConst(loc, iTy, 1, b);
330-
Value cNeg1 = createIntConst(loc, iTy, -1, b);
331-
Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
332-
Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
333-
Value c127 = createIntConst(loc, iTy, (1 << (exponentWidth - 1)) - 1, b);
334-
Value c2To22 = createIntConst(loc, iTy, 1 << (mantissaWidth - 1), b);
335-
Value c23Mask = createIntConst(loc, iTy, (1 << mantissaWidth) - 1, b);
336-
Value expMask = createIntConst(loc, iTy, (1 << exponentWidth) - 1, b);
337-
338-
Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
319+
Value c1Float = createFloatConst(loc, f32Ty, 1.0, b);
320+
Value c0 = createIntConst(loc, i32Ty, 0, b);
321+
Value c1 = createIntConst(loc, i32Ty, 1, b);
322+
Value cNeg1 = createIntConst(loc, i32Ty, -1, b);
323+
Value c23 = createIntConst(loc, i32Ty, 23, b);
324+
Value c31 = createIntConst(loc, i32Ty, 31, b);
325+
Value c127 = createIntConst(loc, i32Ty, 127, b);
326+
Value c2To22 = createIntConst(loc, i32Ty, 1 << 22, b);
327+
Value c23Mask = createIntConst(loc, i32Ty, (1 << 23) - 1, b);
328+
Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
329+
330+
Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
339331
Value round = b.create<math::RoundOp>(operand);
340-
Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
332+
Value roundBitcast = b.create<arith::BitcastOp>(i32Ty, round);
341333

342334
// Get biased exponents for operand and round(operand)
343335
Value operandExp = b.create<arith::AndIOp>(
@@ -348,7 +340,7 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
348340
Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
349341

350342
auto safeShiftRight = [&](Value x, Value shift) -> Value {
351-
// Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
343+
// Clamp shift to valid range [0, 31] to avoid undefined behavior
352344
Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
353345
clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
354346
return b.create<arith::ShRUIOp>(x, clampedShift);

mlir/test/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ if(LLVM_ENABLE_PIC AND TARGET ${LLVM_NATIVE_ARCH})
119119
mlir-capi-execution-engine-test
120120
mlir_c_runner_utils
121121
mlir_runner_utils
122-
mlir_float16_utils
123122
)
124123
endif()
125124

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 2 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
232232

233233
// -----
234234

235-
// CHECK-LABEL: func.func @roundeven32
236-
func.func @roundeven32(%arg: f32) -> f32 {
235+
// CHECK-LABEL: func.func @roundeven
236+
func.func @roundeven(%arg: f32) -> f32 {
237237
%res = math.roundeven %arg : f32
238238
return %res : f32
239239
}
@@ -331,90 +331,3 @@ func.func @roundeven32(%arg: f32) -> f32 {
331331
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f32
332332

333333
// CHECK: return %[[COPYSIGN]] : f32
334-
335-
// -----
336-
337-
// CHECK-LABEL: func.func @roundeven16
338-
func.func @roundeven16(%arg: f16) -> f16 {
339-
%res = math.roundeven %arg : f16
340-
return %res : f16
341-
}
342-
343-
// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 {
344-
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : i16
345-
// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i16
346-
// CHECK-DAG: %[[C_NEG_1:.*]] = arith.constant -1 : i16
347-
// CHECK-DAG: %[[C_1_FLOAT:.*]] = arith.constant 1.000000e+00 : f16
348-
// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i16
349-
// CHECK-DAG: %[[C_15:.*]] = arith.constant 15 : i16
350-
// CHECK-DAG: %[[C_512:.*]] = arith.constant 512 : i16
351-
// CHECK-DAG: %[[C_1023:.*]] = arith.constant 1023 : i16
352-
// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 31 : i16
353-
354-
// CHECK: %[[OPERAND_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f16 to i16
355-
// CHECK: %[[ROUND:.*]] = math.round %[[VAL_0]] : f16
356-
// CHECK: %[[ROUND_BITCAST:.*]] = arith.bitcast %[[ROUND]] : f16 to i16
357-
358-
// Get biased exponents of `round` and `operand`
359-
// CHECK: %[[SHIFTED_OPERAND_BITCAST:.*]] = arith.shrui %[[OPERAND_BITCAST]], %[[C_10]] : i16
360-
// CHECK: %[[OPERAND_EXP:.*]] = arith.andi %[[SHIFTED_OPERAND_BITCAST]], %[[EXP_MASK]] : i16
361-
// CHECK: %[[OPERAND_BIASED_EXP:.*]] = arith.subi %[[OPERAND_EXP]], %[[C_15]] : i16
362-
// CHECK: %[[SHIFTED_ROUND_BITCAST:.*]] = arith.shrui %[[ROUND_BITCAST]], %[[C_10]] : i16
363-
// CHECK: %[[ROUND_EXP:.*]] = arith.andi %[[SHIFTED_ROUND_BITCAST]], %[[EXP_MASK]] : i16
364-
// CHECK: %[[ROUND_BIASED_EXP:.*]] = arith.subi %[[ROUND_EXP]], %[[C_15]] : i16
365-
366-
// Determine if `ROUND_BITCAST` is an even whole number or a special value
367-
// +-inf, +-nan.
368-
// Mask mantissa of `ROUND_BITCAST` with a mask shifted to the right by
369-
// `ROUND_BIASED_EXP - 1`
370-
// CHECK-DAG: %[[ROUND_BIASED_EXP_MINUS_1:.*]] = arith.subi %[[ROUND_BIASED_EXP]], %[[C_1]] : i16
371-
// CHECK-DAG: %[[CLAMPED_SHIFT_0:.*]] = arith.maxsi %[[ROUND_BIASED_EXP_MINUS_1]], %[[C_0]] : i16
372-
// CHECK-DAG: %[[CLAMPED_SHIFT_1:.*]] = arith.minsi %[[CLAMPED_SHIFT_0]], %[[C_15]] : i16
373-
// CHECK-DAG: %[[SHIFTED_MANTISSA_MASK_0:.*]] = arith.shrui %[[C_1023]], %[[CLAMPED_SHIFT_1]] : i16
374-
// CHECK-DAG: %[[ROUND_MASKED_MANTISSA:.*]] = arith.andi %[[ROUND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_0]] : i16
375-
376-
// `ROUND_BITCAST` is not even whole number or special value if masked
377-
// mantissa is != 0 or `ROUND_BIASED_EXP == 0`
378-
// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0:.*]] = arith.cmpi ne, %[[ROUND_MASKED_MANTISSA]], %[[C_0]] : i16
379-
// CHECK-DAG: %[[ROUND_BIASED_EXP_EQ_0:.*]] = arith.cmpi eq, %[[ROUND_BIASED_EXP]], %[[C_0]] : i16
380-
// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1:.*]] = arith.ori %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0]], %[[ROUND_BIASED_EXP_EQ_0]] : i1
381-
382-
// Determine if operand is halfway between two integer values
383-
// CHECK: %[[OPERAND_BIASED_EXP_EQ_NEG_1:.*]] = arith.cmpi eq, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i16
384-
// CHECK: %[[CLAMPED_SHIFT_2:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i16
385-
// CHECK: %[[CLAMPED_SHIFT_3:.*]] = arith.minsi %[[CLAMPED_SHIFT_2]], %[[C_15]] : i16
386-
// CHECK: %[[SHIFTED_2_TO_9:.*]] = arith.shrui %[[C_512]], %[[CLAMPED_SHIFT_3]] : i16
387-
388-
// A value with `0 <= BIASED_EXP < 10` is halfway between two consecutive
389-
// integers if the bit at index `BIASED_EXP` starting from the left in the
390-
// mantissa is 1 and all the bits to the right are zero. For the case where
391-
// `BIASED_EXP == -1, the expected mantissa is all zeros.
392-
// CHECK: %[[EXPECTED_OPERAND_MASKED_MANTISSA:.*]] = arith.select %[[OPERAND_BIASED_EXP_EQ_NEG_1]], %[[C_0]], %[[SHIFTED_2_TO_9]] : i16
393-
394-
// Mask mantissa of `OPERAND_BITCAST` with a mask shifted to the right by
395-
// `OPERAND_BIASED_EXP`
396-
// CHECK: %[[CLAMPED_SHIFT_4:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i16
397-
// CHECK: %[[CLAMPED_SHIFT_5:.*]] = arith.minsi %[[CLAMPED_SHIFT_4]], %[[C_15]] : i16
398-
// CHECK: %[[SHIFTED_MANTISSA_MASK_1:.*]] = arith.shrui %[[C_1023]], %[[CLAMPED_SHIFT_5]] : i16
399-
// CHECK: %[[OPERAND_MASKED_MANTISSA:.*]] = arith.andi %[[OPERAND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_1]] : i16
400-
401-
// The operand is halfway between two integers if the masked mantissa is equal
402-
// to the expected mantissa and the biased exponent is in the range
403-
// [-1, 23).
404-
// CHECK-DAG: %[[OPERAND_BIASED_EXP_GE_NEG_1:.*]] = arith.cmpi sge, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i16
405-
// CHECK-DAG: %[[OPERAND_BIASED_EXP_LT_10:.*]] = arith.cmpi slt, %[[OPERAND_BIASED_EXP]], %[[C_10]] : i16
406-
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_0:.*]] = arith.cmpi eq, %[[OPERAND_MASKED_MANTISSA]], %[[EXPECTED_OPERAND_MASKED_MANTISSA]] : i16
407-
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_1:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_0]], %[[OPERAND_BIASED_EXP_LT_10]] : i1
408-
// CHECK-DAG: %[[OPERAND_IS_HALFWAY_2:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_1]], %[[OPERAND_BIASED_EXP_GE_NEG_1]] : i1
409-
410-
// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
411-
// case where `round` rounded in the oppositve direction of `roundeven`.
412-
// CHECK: %[[SIGN:.*]] = math.copysign %[[C_1_FLOAT]], %[[VAL_0]] : f16
413-
// CHECK: %[[ROUND_SHIFTED:.*]] = arith.subf %[[ROUND]], %[[SIGN]] : f16
414-
// CHECK: %[[NEEDS_SHIFT:.*]] = arith.andi %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1]], %[[OPERAND_IS_HALFWAY_2]] : i1
415-
// CHECK: %[[RESULT:.*]] = arith.select %[[NEEDS_SHIFT]], %[[ROUND_SHIFTED]], %[[ROUND]] : f16
416-
417-
// The `x - sign` adjustment does not preserve the sign when we are adjusting the value -1 to -0.
418-
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
419-
420-
// CHECK: return %[[COPYSIGN]] : f16

mlir/test/lit.cfg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def add_runtime(name):
9898
add_runtime("mlir_runner_utils"),
9999
add_runtime("mlir_c_runner_utils"),
100100
add_runtime("mlir_async_runtime"),
101-
add_runtime("mlir_float16_utils"),
102101
"mlir-linalg-ods-yaml-gen",
103102
"mlir-reduce",
104103
"mlir-pdll",

0 commit comments

Comments
 (0)