@@ -261,68 +261,62 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
261
261
return rewriter.notifyMatchFailure (op, " not a trunc of f32 to bf16." );
262
262
}
263
263
264
- Type i1Ty = b.getI1Type ();
265
264
Type i16Ty = b.getI16Type ();
266
265
Type i32Ty = b.getI32Type ();
267
266
Type f32Ty = b.getF32Type ();
268
267
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
269
- i1Ty = shapedTy.clone (i1Ty);
270
268
i16Ty = shapedTy.clone (i16Ty);
271
269
i32Ty = shapedTy.clone (i32Ty);
272
270
f32Ty = shapedTy.clone (f32Ty);
273
271
}
274
272
275
- Value bitcast = b.create <arith::BitcastOp>(i32Ty, operand);
276
-
277
- Value c23 = createConst (op.getLoc (), i32Ty, 23 , rewriter);
278
- Value c31 = createConst (op.getLoc (), i32Ty, 31 , rewriter);
279
- Value c23Mask = createConst (op.getLoc (), i32Ty, (1 << 23 ) - 1 , rewriter);
280
- Value expMask =
281
- createConst (op.getLoc (), i32Ty, ((1 << 8 ) - 1 ) << 23 , rewriter);
282
- Value expMax =
283
- createConst (op.getLoc (), i32Ty, ((1 << 8 ) - 2 ) << 23 , rewriter);
284
-
285
- // Grab the sign bit.
286
- Value sign = b.create <arith::ShRUIOp>(bitcast, c31);
287
-
288
- // Our mantissa rounding value depends on the sign bit and the last
289
- // truncated bit.
290
- Value cManRound = createConst (op.getLoc (), i32Ty, (1 << 15 ), rewriter);
291
- cManRound = b.create <arith::SubIOp>(cManRound, sign);
292
-
293
- // Grab out the mantissa and directly apply rounding.
294
- Value man = b.create <arith::AndIOp>(bitcast, c23Mask);
295
- Value manRound = b.create <arith::AddIOp>(man, cManRound);
296
-
297
- // Grab the overflow bit and shift right if we overflow.
298
- Value roundBit = b.create <arith::ShRUIOp>(manRound, c23);
299
- Value manNew = b.create <arith::ShRUIOp>(manRound, roundBit);
300
-
301
- // Grab the exponent and round using the mantissa's carry bit.
302
- Value exp = b.create <arith::AndIOp>(bitcast, expMask);
303
- Value expCarry = b.create <arith::AddIOp>(exp, manRound);
304
- expCarry = b.create <arith::AndIOp>(expCarry, expMask);
305
-
306
- // If the exponent is saturated, we keep the max value.
307
- Value expCmp =
308
- b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
309
- exp = b.create <arith::SelectOp>(expCmp, exp, expCarry);
310
-
311
- // If the exponent is max and we rolled over, keep the old mantissa.
312
- Value roundBitBool = b.create <arith::TruncIOp>(i1Ty, roundBit);
313
- Value keepOldMan = b.create <arith::AndIOp>(expCmp, roundBitBool);
314
- man = b.create <arith::SelectOp>(keepOldMan, man, manNew);
315
-
316
- // Assemble the now rounded f32 value (as an i32).
317
- Value rounded = b.create <arith::ShLIOp>(sign, c31);
318
- rounded = b.create <arith::OrIOp>(rounded, exp);
319
- rounded = b.create <arith::OrIOp>(rounded, man);
320
-
273
+ // Algorithm borrowed from this excellent code:
274
+ // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
275
+ // There is a magic idea there, to let the addition of the rounding_bias to
276
+ // the mantissa simply overflow into the exponent bits. It's a bit of an
277
+ // aggressive, obfuscating optimization, but it is well-tested code, and it
278
+ // results in more concise and efficient IR.
279
+ // The case of NaN is handled separately (see isNaN and the final select).
280
+ // The case of infinities is NOT handled separately, which deserves an
281
+ // explanation. As the encoding of infinities has zero mantissa, the
282
+ // rounding-bias addition never carries into the exponent so that just gets
283
+ // truncated away, and as bfloat16 and float32 have the same number of
284
+ // exponent bits, that simple truncation is the desired outcome for
285
+ // infinities.
286
+ Value isNan =
287
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
288
+ // Constant used to make the rounding bias.
289
+ Value c7FFF = createConst (op.getLoc (), i32Ty, 0x7fff , rewriter);
290
+ // Constant used to generate a quiet NaN.
291
+ Value c7FC0_i16 = createConst (op.getLoc (), i16Ty, 0x7fc0 , rewriter);
292
+ // Small constants used to address bits.
321
293
Value c16 = createConst (op.getLoc (), i32Ty, 16 , rewriter);
322
- Value shr = b.create <arith::ShRUIOp>(rounded, c16);
323
- Value trunc = b.create <arith::TruncIOp>(i16Ty, shr);
324
- Value result = b.create <arith::BitcastOp>(resultTy, trunc);
325
-
294
+ Value c1 = createConst (op.getLoc (), i32Ty, 1 , rewriter);
295
+ // Reinterpret the input f32 value as bits.
296
+ Value bitcast = b.create <arith::BitcastOp>(i32Ty, operand);
297
+ // Read bit 16 as a value in {0,1}.
298
+ Value bit16 =
299
+ b.create <arith::AndIOp>(b.create <arith::ShRUIOp>(bitcast, c16), c1);
300
+ // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
301
+ // on bit 16, implementing the tie-breaking "to nearest even".
302
+ Value roundingBias = b.create <arith::AddIOp>(bit16, c7FFF);
303
+ // Add the rounding bias. Generally we want this to be added to the
304
+ // mantissa, but nothing prevents this from carrying into the exponent
305
+ // bits, which would feel like a bug, but this is the magic trick here:
306
+ // when that happens, the mantissa gets reset to zero and the exponent
307
+ // gets incremented by the carry... which is actually exactly what we
308
+ // want.
309
+ Value biased = b.create <arith::AddIOp>(bitcast, roundingBias);
310
+ // Now that the rounding-bias has been added, truncating the low bits
311
+ // yields the correctly rounded result.
312
+ Value biasedAndShifted = b.create <arith::ShRUIOp>(biased, c16);
313
+ Value normalCaseResult_i16 =
314
+ b.create <arith::TruncIOp>(i16Ty, biasedAndShifted);
315
+ // Select either the above-computed result, or a quiet NaN constant
316
+ // if the input was NaN.
317
+ Value select =
318
+ b.create <arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
319
+ Value result = b.create <arith::BitcastOp>(resultTy, select);
326
320
rewriter.replaceOp (op, result);
327
321
return success ();
328
322
}
0 commit comments