
Commit 9c7cde6

Fix the lowering of arith.truncf : f32 to bf16. (#83180)
This lowering was not correctly handling the case where rounding a saturated mantissa carries over and increases the exponent value. The new code borrows, with credit, the idea from https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79 and adds comments explaining the magic trick going on here and why it is correct. Hat tip to its original author, whom I believe to be @Maratyszcza. An existing testcase also required a tie to be broken upwards in a case where "to nearest-even" rounding requires going downward; the fact that it used to pass suggests that there was another bug in the old code.
1 parent 2eb6398 commit 9c7cde6
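
To make the "magic trick" concrete before reading the diff, here is a minimal scalar C++ sketch of the same rounding scheme, modeled on the PyTorch code credited above. It is illustrative only: the helper name and the standalone-function form are invented for this note, and the actual change emits MLIR arith ops instead (see ExpandOps.cpp below).

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-to-nearest-even truncation of an IEEE binary32 float to the
// bfloat16 encoding, returned as raw bits.
static uint16_t truncf_f32_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // reinterpret the f32 as an i32
  if (std::isnan(f))
    return 0x7FC0;                       // quiet NaN, as in the new lowering
  // The bias is 0x7fff, plus 1 when bit 16 (the lowest bit that survives
  // truncation) is set, so that exact ties round to the nearest even value.
  uint32_t bias = 0x7FFF + ((bits >> 16) & 1);
  // The addition may carry out of the mantissa into the exponent field;
  // that carry is exactly the "round up into the next binade" case,
  // including the overflow of the largest finite values to infinity.
  bits += bias;
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  // 0x3f808000 is an exact tie; nearest-even rounding goes down to 1.0 (0x3f80).
  uint32_t tieBits = 0x3f808000u;
  float tie;
  std::memcpy(&tie, &tieBits, sizeof(tie));
  std::printf("0x%04x\n", truncf_f32_to_bf16(tie));  // prints 0x3f80
  return 0;
}

The two tie cases this sketch reproduces (0x3f808000 rounding down to 1.0, 0x3f818000 rounding up to 1.015625) are exactly the ones exercised by the updated mlir-cpu-runner test below.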

File tree (3 files changed: +96 -94 lines changed)

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
mlir/test/Dialect/Arith/expand-ops.mlir
mlir/test/mlir-cpu-runner/expand-arith-ops.mlir

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 46 additions & 52 deletions
@@ -261,68 +261,62 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
       return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
     }
 
-    Type i1Ty = b.getI1Type();
     Type i16Ty = b.getI16Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();
     if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i1Ty = shapedTy.clone(i1Ty);
       i16Ty = shapedTy.clone(i16Ty);
       i32Ty = shapedTy.clone(i32Ty);
       f32Ty = shapedTy.clone(f32Ty);
     }
 
-    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
-
-    Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
-    Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
-    Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
-    Value expMask =
-        createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
-    Value expMax =
-        createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);
-
-    // Grab the sign bit.
-    Value sign = b.create<arith::ShRUIOp>(bitcast, c31);
-
-    // Our mantissa rounding value depends on the sign bit and the last
-    // truncated bit.
-    Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
-    cManRound = b.create<arith::SubIOp>(cManRound, sign);
-
-    // Grab out the mantissa and directly apply rounding.
-    Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
-    Value manRound = b.create<arith::AddIOp>(man, cManRound);
-
-    // Grab the overflow bit and shift right if we overflow.
-    Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
-    Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);
-
-    // Grab the exponent and round using the mantissa's carry bit.
-    Value exp = b.create<arith::AndIOp>(bitcast, expMask);
-    Value expCarry = b.create<arith::AddIOp>(exp, manRound);
-    expCarry = b.create<arith::AndIOp>(expCarry, expMask);
-
-    // If the exponent is saturated, we keep the max value.
-    Value expCmp =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
-    exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);
-
-    // If the exponent is max and we rolled over, keep the old mantissa.
-    Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
-    Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
-    man = b.create<arith::SelectOp>(keepOldMan, man, manNew);
-
-    // Assemble the now rounded f32 value (as an i32).
-    Value rounded = b.create<arith::ShLIOp>(sign, c31);
-    rounded = b.create<arith::OrIOp>(rounded, exp);
-    rounded = b.create<arith::OrIOp>(rounded, man);
-
+    // Algorithm borrowed from this excellent code:
+    // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
+    // There is a magic idea there, to let the addition of the rounding_bias to
+    // the mantissa simply overflow into the exponent bits. It's a bit of an
+    // aggressive, obfuscating optimization, but it is well-tested code, and it
+    // results in more concise and efficient IR.
+    // The case of NaN is handled separately (see isNaN and the final select).
+    // The case of infinities is NOT handled separately, which deserves an
+    // explanation. As the encoding of infinities has zero mantissa, the
+    // rounding-bias addition never carries into the exponent so that just gets
+    // truncated away, and as bfloat16 and float32 have the same number of
+    // exponent bits, that simple truncation is the desired outcome for
+    // infinities.
+    Value isNan =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
+    // Constant used to make the rounding bias.
+    Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
+    // Constant used to generate a quiet NaN.
+    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    // Small constants used to address bits.
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
-    Value shr = b.create<arith::ShRUIOp>(rounded, c16);
-    Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
-    Value result = b.create<arith::BitcastOp>(resultTy, trunc);
-
+    Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
+    // Reinterpret the input f32 value as bits.
+    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+    // Read bit 16 as a value in {0,1}.
+    Value bit16 =
+        b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
+    // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
+    // on bit 16, implementing the tie-breaking "to nearest even".
+    Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
+    // Add the rounding bias. Generally we want this to be added to the
+    // mantissa, but nothing prevents it from carrying into the exponent
+    // bits, which would feel like a bug, but this is the magic trick here:
+    // when that happens, the mantissa gets reset to zero and the exponent
+    // gets incremented by the carry... which is actually exactly what we
+    // want.
+    Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
+    // Now that the rounding-bias has been added, truncating the low bits
+    // yields the correctly rounded result.
+    Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
+    Value normalCaseResult_i16 =
+        b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
+    // Select either the above-computed result, or a quiet NaN constant
+    // if the input was NaN.
+    Value select =
+        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+    Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
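
As a quick check of the infinity remark in the comment block added above, the bit-level arithmetic for +inf works out as follows:

  bits                = 0x7f800000   (+inf: exponent all ones, mantissa zero)
  bit 16              = 0, so the rounding bias is 0x7fff
  bits + bias         = 0x7f807fff   (the bias never reaches the exponent field)
  (bits + bias) >> 16 = 0x7f80       (still the bf16 encoding of +inf)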

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 15 additions & 30 deletions
@@ -255,36 +255,21 @@ func.func @truncf_f32(%arg0 : f32) -> bf16 {
 }
 
 // CHECK-LABEL: @truncf_f32
-
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16
-// CHECK-DAG: %[[C32768:.+]] = arith.constant 32768
-// CHECK-DAG: %[[C2130706432:.+]] = arith.constant 2130706432
-// CHECK-DAG: %[[C2139095040:.+]] = arith.constant 2139095040
-// CHECK-DAG: %[[C8388607:.+]] = arith.constant 8388607
-// CHECK-DAG: %[[C31:.+]] = arith.constant 31
-// CHECK-DAG: %[[C23:.+]] = arith.constant 23
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0
-// CHECK-DAG: %[[SIGN:.+]] = arith.shrui %[[BITCAST:.+]], %[[C31]]
-// CHECK-DAG: %[[ROUND:.+]] = arith.subi %[[C32768]], %[[SIGN]]
-// CHECK-DAG: %[[MANTISSA:.+]] = arith.andi %[[BITCAST]], %[[C8388607]]
-// CHECK-DAG: %[[ROUNDED:.+]] = arith.addi %[[MANTISSA]], %[[ROUND]]
-// CHECK-DAG: %[[ROLL:.+]] = arith.shrui %[[ROUNDED]], %[[C23]]
-// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[ROUNDED]], %[[ROLL]]
-// CHECK-DAG: %[[EXP:.+]] = arith.andi %0, %[[C2139095040]]
-// CHECK-DAG: %[[EXPROUND:.+]] = arith.addi %[[EXP]], %[[ROUNDED]]
-// CHECK-DAG: %[[EXPROLL:.+]] = arith.andi %[[EXPROUND]], %[[C2139095040]]
-// CHECK-DAG: %[[EXPMAX:.+]] = arith.cmpi uge, %[[EXP]], %[[C2130706432]]
-// CHECK-DAG: %[[EXPNEW:.+]] = arith.select %[[EXPMAX]], %[[EXP]], %[[EXPROLL]]
-// CHECK-DAG: %[[OVERFLOW_B:.+]] = arith.trunci %[[ROLL]]
-// CHECK-DAG: %[[KEEP_MAN:.+]] = arith.andi %[[EXPMAX]], %[[OVERFLOW_B]]
-// CHECK-DAG: %[[MANNEW:.+]] = arith.select %[[KEEP_MAN]], %[[MANTISSA]], %[[SHR]]
-// CHECK-DAG: %[[NEWSIGN:.+]] = arith.shli %[[SIGN]], %[[C31]]
-// CHECK-DAG: %[[WITHEXP:.+]] = arith.ori %[[NEWSIGN]], %[[EXPNEW]]
-// CHECK-DAG: %[[WITHMAN:.+]] = arith.ori %[[WITHEXP]], %[[MANNEW]]
-// CHECK-DAG: %[[SHIFT:.+]] = arith.shrui %[[WITHMAN]], %[[C16]]
-// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHIFT]]
-// CHECK-DAG: %[[RES:.+]] = arith.bitcast %[[TRUNC]]
-// CHECK: return %[[RES]]
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C7FC0_i16:.+]] = arith.constant 32704 : i16
+// CHECK-DAG: %[[C7FFF:.+]] = arith.constant 32767 : i32
+// CHECK-DAG: %[[ISNAN:.+]] = arith.cmpf une, %arg0, %arg0 : f32
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK-DAG: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C16]] : i32
+// CHECK-DAG: %[[BIT16:.+]] = arith.andi %[[SHRUI]], %[[C1]] : i32
+// CHECK-DAG: %[[ROUNDING_BIAS:.+]] = arith.addi %[[BIT16]], %[[C7FFF]] : i32
+// CHECK-DAG: %[[BIASED:.+]] = arith.addi %[[BITCAST]], %[[ROUNDING_BIAS]] : i32
+// CHECK-DAG: %[[BIASED_SHIFTED:.+]] = arith.shrui %[[BIASED]], %[[C16]] : i32
+// CHECK-DAG: %[[NORMAL_CASE_RESULT_i16:.+]] = arith.trunci %[[BIASED_SHIFTED]] : i32 to i16
+// CHECK-DAG: %[[SELECT:.+]] = arith.select %[[ISNAN]], %[[C7FC0_i16]], %[[NORMAL_CASE_RESULT_i16]] : i16
+// CHECK-DAG: %[[RESULT:.+]] = arith.bitcast %[[SELECT]] : i16 to bf16
+// CHECK: return %[[RESULT]]
 
 // -----
 
mlir/test/mlir-cpu-runner/expand-arith-ops.mlir

Lines changed: 35 additions & 12 deletions
@@ -13,10 +13,21 @@ func.func @trunc_bf16(%a : f32) {
 }
 
 func.func @main() {
-  // CHECK: 1.00781
-  %roundOneI = arith.constant 0x3f808000 : i32
-  %roundOneF = arith.bitcast %roundOneI : i32 to f32
-  call @trunc_bf16(%roundOneF): (f32) -> ()
+  // Note: this is a tie (low 16 bits are 0x8000). We expect the rounding behavior
+  // to break ties "to nearest-even", which in this case means downwards,
+  // since bit 16 is not set.
+  // CHECK: 1
+  %value_1_00391_I = arith.constant 0x3f808000 : i32
+  %value_1_00391_F = arith.bitcast %value_1_00391_I : i32 to f32
+  call @trunc_bf16(%value_1_00391_F): (f32) -> ()
+
+  // Note: this is a tie (low 16 bits are 0x8000). We expect the rounding behavior
+  // to break ties "to nearest-even", which in this case means upwards,
+  // since bit 16 is set.
+  // CHECK-NEXT: 1.01562
+  %value_1_01172_I = arith.constant 0x3f818000 : i32
+  %value_1_01172_F = arith.bitcast %value_1_01172_I : i32 to f32
+  call @trunc_bf16(%value_1_01172_F): (f32) -> ()
 
   // CHECK-NEXT: -1
   %noRoundNegOneI = arith.constant 0xbf808000 : i32

@@ -38,15 +49,27 @@ func.func @main() {
   %neginff = arith.bitcast %neginfi : i32 to f32
   call @trunc_bf16(%neginff): (f32) -> ()
 
+  // Note: this rounds upwards. As the mantissa was already saturated, this rounding
+  // causes the exponent to be incremented. As the exponent was already the
+  // maximum exponent value for finite values, this increment of the exponent
+  // causes this to overflow to +inf.
+  // CHECK-NEXT: inf
+  %big_overflowing_i = arith.constant 0x7f7fffff : i32
+  %big_overflowing_f = arith.bitcast %big_overflowing_i : i32 to f32
+  call @trunc_bf16(%big_overflowing_f): (f32) -> ()
+
+  // Same as the previous testcase but negative.
+  // CHECK-NEXT: -inf
+  %negbig_overflowing_i = arith.constant 0xff7fffff : i32
+  %negbig_overflowing_f = arith.bitcast %negbig_overflowing_i : i32 to f32
+  call @trunc_bf16(%negbig_overflowing_f): (f32) -> ()
+
+  // In contrast to the previous two testcases, the upwards-rounding here
+  // does not cause overflow.
   // CHECK-NEXT: 3.38953e+38
-  %bigi = arith.constant 0x7f7fffff : i32
-  %bigf = arith.bitcast %bigi : i32 to f32
-  call @trunc_bf16(%bigf): (f32) -> ()
-
-  // CHECK-NEXT: -3.38953e+38
-  %negbigi = arith.constant 0xff7fffff : i32
-  %negbigf = arith.bitcast %negbigi : i32 to f32
-  call @trunc_bf16(%negbigf): (f32) -> ()
+  %big_nonoverflowing_i = arith.constant 0x7f7effff : i32
+  %big_nonoverflowing_f = arith.bitcast %big_nonoverflowing_i : i32 to f32
+  call @trunc_bf16(%big_nonoverflowing_f): (f32) -> ()
 
   // CHECK-NEXT: 1.625
   %exprolli = arith.constant 0x3fcfffff : i32
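
As a sanity check of the overflow comments in this test, the same arithmetic for 0x7f7fffff (the largest finite f32) is:

  bits                = 0x7f7fffff
  bit 16              = 1, so the rounding bias is 0x7fff + 1 = 0x8000
  bits + bias         = 0x7f807fff   (the carry runs into the exponent field)
  (bits + bias) >> 16 = 0x7f80       (the bf16 encoding of +inf)

which matches the "CHECK-NEXT: inf" expectation, while 0x7f7effff biases to 0x7f7f7ffe and truncates to 0x7f7f, the largest finite bf16 (3.38953e+38).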
