Skip to content

Fix the lowering of arith.truncf : f32 to bf16. #83180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 46 additions & 52 deletions mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,68 +261,62 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
}

Type i1Ty = b.getI1Type();
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
i1Ty = shapedTy.clone(i1Ty);
i16Ty = shapedTy.clone(i16Ty);
i32Ty = shapedTy.clone(i32Ty);
f32Ty = shapedTy.clone(f32Ty);
}

Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);

Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
Value expMask =
createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
Value expMax =
createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);

// Grab the sign bit.
Value sign = b.create<arith::ShRUIOp>(bitcast, c31);

// Our mantissa rounding value depends on the sign bit and the last
// truncated bit.
Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
cManRound = b.create<arith::SubIOp>(cManRound, sign);

// Grab out the mantissa and directly apply rounding.
Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
Value manRound = b.create<arith::AddIOp>(man, cManRound);

// Grab the overflow bit and shift right if we overflow.
Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);

// Grab the exponent and round using the mantissa's carry bit.
Value exp = b.create<arith::AndIOp>(bitcast, expMask);
Value expCarry = b.create<arith::AddIOp>(exp, manRound);
expCarry = b.create<arith::AndIOp>(expCarry, expMask);

// If the exponent is saturated, we keep the max value.
Value expCmp =
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);

// If the exponent is max and we rolled over, keep the old mantissa.
Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
man = b.create<arith::SelectOp>(keepOldMan, man, manNew);

// Assemble the now rounded f32 value (as an i32).
Value rounded = b.create<arith::ShLIOp>(sign, c31);
rounded = b.create<arith::OrIOp>(rounded, exp);
rounded = b.create<arith::OrIOp>(rounded, man);

// Algorithm borrowed from this excellent code:
// https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
// There is a magic idea there, to let the addition of the rounding_bias to
// the mantissa simply overflow into the exponent bits. It's a bit of an
// aggressive, obfuscating optimization, but it is well-tested code, and it
// results in more concise and efficient IR.
// The case of NaN is handled separately (see isNaN and the final select).
// The case of infinities is NOT handled separately, which deserves an
// explanation. As the encoding of infinities has zero mantissa, the
// rounding-bias addition never carries into the exponent so that just gets
// truncated away, and as bfloat16 and float32 have the same number of
// exponent bits, that simple truncation is the desired outcome for
// infinities.
Value isNan =
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
// Small constants used to address bits.
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value shr = b.create<arith::ShRUIOp>(rounded, c16);
Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
Value result = b.create<arith::BitcastOp>(resultTy, trunc);

Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
// Reinterpret the input f32 value as bits.
Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
// Read bit 16 as a value in {0,1}.
Value bit16 =
b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
// Determine the rounding bias to add as either 0x7fff or 0x8000 depending
// on bit 16, implementing the tie-breaking "to nearest even".
Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
// Add the rounding bias. Generally we want this to be added to the
// mantissa, but nothing prevents this to from carrying into the exponent
// bits, which would feel like a bug, but this is the magic trick here:
// when that happens, the mantissa gets reset to zero and the exponent
// gets incremented by the carry... which is actually exactly what we
// want.
Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
Value normalCaseResult_i16 =
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
Value result = b.create<arith::BitcastOp>(resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
Expand Down
45 changes: 15 additions & 30 deletions mlir/test/Dialect/Arith/expand-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -255,36 +255,21 @@ func.func @truncf_f32(%arg0 : f32) -> bf16 {
}

// CHECK-LABEL: @truncf_f32

// CHECK-DAG: %[[C16:.+]] = arith.constant 16
// CHECK-DAG: %[[C32768:.+]] = arith.constant 32768
// CHECK-DAG: %[[C2130706432:.+]] = arith.constant 2130706432
// CHECK-DAG: %[[C2139095040:.+]] = arith.constant 2139095040
// CHECK-DAG: %[[C8388607:.+]] = arith.constant 8388607
// CHECK-DAG: %[[C31:.+]] = arith.constant 31
// CHECK-DAG: %[[C23:.+]] = arith.constant 23
// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0
// CHECK-DAG: %[[SIGN:.+]] = arith.shrui %[[BITCAST:.+]], %[[C31]]
// CHECK-DAG: %[[ROUND:.+]] = arith.subi %[[C32768]], %[[SIGN]]
// CHECK-DAG: %[[MANTISSA:.+]] = arith.andi %[[BITCAST]], %[[C8388607]]
// CHECK-DAG: %[[ROUNDED:.+]] = arith.addi %[[MANTISSA]], %[[ROUND]]
// CHECK-DAG: %[[ROLL:.+]] = arith.shrui %[[ROUNDED]], %[[C23]]
// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[ROUNDED]], %[[ROLL]]
// CHECK-DAG: %[[EXP:.+]] = arith.andi %0, %[[C2139095040]]
// CHECK-DAG: %[[EXPROUND:.+]] = arith.addi %[[EXP]], %[[ROUNDED]]
// CHECK-DAG: %[[EXPROLL:.+]] = arith.andi %[[EXPROUND]], %[[C2139095040]]
// CHECK-DAG: %[[EXPMAX:.+]] = arith.cmpi uge, %[[EXP]], %[[C2130706432]]
// CHECK-DAG: %[[EXPNEW:.+]] = arith.select %[[EXPMAX]], %[[EXP]], %[[EXPROLL]]
// CHECK-DAG: %[[OVERFLOW_B:.+]] = arith.trunci %[[ROLL]]
// CHECK-DAG: %[[KEEP_MAN:.+]] = arith.andi %[[EXPMAX]], %[[OVERFLOW_B]]
// CHECK-DAG: %[[MANNEW:.+]] = arith.select %[[KEEP_MAN]], %[[MANTISSA]], %[[SHR]]
// CHECK-DAG: %[[NEWSIGN:.+]] = arith.shli %[[SIGN]], %[[C31]]
// CHECK-DAG: %[[WITHEXP:.+]] = arith.ori %[[NEWSIGN]], %[[EXPNEW]]
// CHECK-DAG: %[[WITHMAN:.+]] = arith.ori %[[WITHEXP]], %[[MANNEW]]
// CHECK-DAG: %[[SHIFT:.+]] = arith.shrui %[[WITHMAN]], %[[C16]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHIFT]]
// CHECK-DAG: %[[RES:.+]] = arith.bitcast %[[TRUNC]]
// CHECK: return %[[RES]]
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32
// CHECK-DAG: %[[C7FC0_i16:.+]] = arith.constant 32704 : i16
// CHECK-DAG: %[[C7FFF:.+]] = arith.constant 32767 : i32
// CHECK-DAG: %[[ISNAN:.+]] = arith.cmpf une, %arg0, %arg0 : f32
// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
// CHECK-DAG: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C16]] : i32
// CHECK-DAG: %[[BIT16:.+]] = arith.andi %[[SHRUI]], %[[C1]] : i32
// CHECK-DAG: %[[ROUNDING_BIAS:.+]] = arith.addi %[[BIT16]], %[[C7FFF]] : i32
// CHECK-DAG: %[[BIASED:.+]] = arith.addi %[[BITCAST]], %[[ROUNDING_BIAS]] : i32
// CHECK-DAG: %[[BIASED_SHIFTED:.+]] = arith.shrui %[[BIASED]], %[[C16]] : i32
// CHECK-DAG: %[[NORMAL_CASE_RESULT_i16:.+]] = arith.trunci %[[BIASED_SHIFTED]] : i32 to i16
// CHECK-DAG: %[[SELECT:.+]] = arith.select %[[ISNAN]], %[[C7FC0_i16]], %[[NORMAL_CASE_RESULT_i16]] : i16
// CHECK-DAG: %[[RESULT:.+]] = arith.bitcast %[[SELECT]] : i16 to bf16
// CHECK: return %[[RESULT]]

// -----

Expand Down
47 changes: 35 additions & 12 deletions mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,21 @@ func.func @trunc_bf16(%a : f32) {
}

func.func @main() {
// CHECK: 1.00781
%roundOneI = arith.constant 0x3f808000 : i32
%roundOneF = arith.bitcast %roundOneI : i32 to f32
call @trunc_bf16(%roundOneF): (f32) -> ()
// Note: this is a tie (low 16 bits are 0x8000). We expect the rounding behavior
// to break ties "to nearest-even", which in this case means downwards,
// since bit 16 is not set.
// CHECK: 1
%value_1_00391_I = arith.constant 0x3f808000 : i32
%value_1_00391_F = arith.bitcast %value_1_00391_I : i32 to f32
call @trunc_bf16(%value_1_00391_F): (f32) -> ()

// Note: this is a tie (low 16 bits are 0x8000). We expect the rounding behavior
// to break ties "to nearest-even", which in this case means upwards,
// since bit 16 is set.
// CHECK-NEXT: 1.01562
%value_1_01172_I = arith.constant 0x3f818000 : i32
%value_1_01172_F = arith.bitcast %value_1_01172_I : i32 to f32
call @trunc_bf16(%value_1_01172_F): (f32) -> ()

// CHECK-NEXT: -1
%noRoundNegOneI = arith.constant 0xbf808000 : i32
Expand All @@ -38,15 +49,27 @@ func.func @main() {
%neginff = arith.bitcast %neginfi : i32 to f32
call @trunc_bf16(%neginff): (f32) -> ()

// Note: this rounds upwards. As the mantissa was already saturated, this rounding
// causes the exponent to be incremented. As the exponent was already the
// maximum exponent value for finite values, this increment of the exponent
// causes this to overflow to +inf.
// CHECK-NEXT: inf
%big_overflowing_i = arith.constant 0x7f7fffff : i32
%big_overflowing_f = arith.bitcast %big_overflowing_i : i32 to f32
call @trunc_bf16(%big_overflowing_f): (f32) -> ()

// Same as the previous testcase but negative.
// CHECK-NEXT: -inf
%negbig_overflowing_i = arith.constant 0xff7fffff : i32
%negbig_overflowing_f = arith.bitcast %negbig_overflowing_i : i32 to f32
call @trunc_bf16(%negbig_overflowing_f): (f32) -> ()

// In contrast to the previous two testcases, the upwards-rounding here
// does not cause overflow.
// CHECK-NEXT: 3.38953e+38
%bigi = arith.constant 0x7f7fffff : i32
%bigf = arith.bitcast %bigi : i32 to f32
call @trunc_bf16(%bigf): (f32) -> ()

// CHECK-NEXT: -3.38953e+38
%negbigi = arith.constant 0xff7fffff : i32
%negbigf = arith.bitcast %negbigi : i32 to f32
call @trunc_bf16(%negbigf): (f32) -> ()
%big_nonoverflowing_i = arith.constant 0x7f7effff : i32
%big_nonoverflowing_f = arith.bitcast %big_nonoverflowing_i : i32 to f32
call @trunc_bf16(%big_nonoverflowing_f): (f32) -> ()

// CHECK-NEXT: 1.625
%exprolli = arith.constant 0x3fcfffff : i32
Expand Down