Skip to content

Commit e9ef520

Browse files
committed
magic bf16
1 parent b0bae44 commit e9ef520

File tree

1 file changed

+22
-52
lines changed

1 file changed

+22
-52
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 22 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -261,68 +261,38 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
261261
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
262262
}
263263

264-
Type i1Ty = b.getI1Type();
265264
Type i16Ty = b.getI16Type();
266265
Type i32Ty = b.getI32Type();
267266
Type f32Ty = b.getF32Type();
268267
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
269-
i1Ty = shapedTy.clone(i1Ty);
270268
i16Ty = shapedTy.clone(i16Ty);
271269
i32Ty = shapedTy.clone(i32Ty);
272270
f32Ty = shapedTy.clone(f32Ty);
273271
}
274272

275-
Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
276-
277-
Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
278-
Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
279-
Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
280-
Value expMask =
281-
createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
282-
Value expMax =
283-
createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);
284-
285-
// Grab the sign bit.
286-
Value sign = b.create<arith::ShRUIOp>(bitcast, c31);
287-
288-
// Our mantissa rounding value depends on the sign bit and the last
289-
// truncated bit.
290-
Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
291-
cManRound = b.create<arith::SubIOp>(cManRound, sign);
292-
293-
// Grab out the mantissa and directly apply rounding.
294-
Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
295-
Value manRound = b.create<arith::AddIOp>(man, cManRound);
296-
297-
// Grab the overflow bit and shift right if we overflow.
298-
Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
299-
Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);
300-
301-
// Grab the exponent and round using the mantissa's carry bit.
302-
Value exp = b.create<arith::AndIOp>(bitcast, expMask);
303-
Value expCarry = b.create<arith::AddIOp>(exp, manRound);
304-
expCarry = b.create<arith::AndIOp>(expCarry, expMask);
305-
306-
// If the exponent is saturated, we keep the max value.
307-
Value expCmp =
308-
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
309-
exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);
310-
311-
// If the exponent is max and we rolled over, keep the old mantissa.
312-
Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
313-
Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
314-
man = b.create<arith::SelectOp>(keepOldMan, man, manNew);
315-
316-
// Assemble the now rounded f32 value (as an i32).
317-
Value rounded = b.create<arith::ShLIOp>(sign, c31);
318-
rounded = b.create<arith::OrIOp>(rounded, exp);
319-
rounded = b.create<arith::OrIOp>(rounded, man);
320-
273+
// Algorithm borrowed from this excellent code:
274+
// https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
275+
// There is a magic idea there, to let the addition of the rounding_bias to
276+
// the mantissa simply overflow into the exponent bits. It's a bit of an
277+
// aggressive, obfuscating optimization, but it is well-tested code, and it
278+
// results in more concise and efficient IR.
279+
Value isNan =
280+
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
281+
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
282+
Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
321283
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
322-
Value shr = b.create<arith::ShRUIOp>(rounded, c16);
323-
Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
324-
Value result = b.create<arith::BitcastOp>(resultTy, trunc);
325-
284+
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
285+
Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
286+
Value bit16 =
287+
b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
288+
Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
289+
Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
290+
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
291+
Value normalCaseResult_i16 =
292+
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
293+
Value select =
294+
b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
295+
Value result = b.create<arith::BitcastOp>(resultTy, select);
326296
rewriter.replaceOp(op, result);
327297
return success();
328298
}

0 commit comments

Comments
 (0)