Skip to content

Commit 9e759f3

Browse files
committed
[AArch64] Fix fptoi/itofp for bf16
There were a number of issues that needed to be addressed: - i64 to bf16 did not correctly round - strict rounding needed to yield a chain - fastisel did not have logic to bail on bf16
1 parent 3f7aa04 commit 9e759f3

File tree

6 files changed

+672
-74
lines changed

6 files changed

+672
-74
lines changed

llvm/lib/Target/AArch64/AArch64FastISel.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2828,7 +2828,7 @@ bool AArch64FastISel::selectFPToInt(const Instruction *I, bool Signed) {
28282828
return false;
28292829

28302830
EVT SrcVT = TLI.getValueType(DL, I->getOperand(0)->getType(), true);
2831-
if (SrcVT == MVT::f128 || SrcVT == MVT::f16)
2831+
if (SrcVT == MVT::f128 || SrcVT == MVT::f16 || SrcVT == MVT::bf16)
28322832
return false;
28332833

28342834
unsigned Opc;
@@ -2856,7 +2856,7 @@ bool AArch64FastISel::selectIntToFP(const Instruction *I, bool Signed) {
28562856
if (!isTypeLegal(I->getType(), DestVT) || DestVT.isVector())
28572857
return false;
28582858
// Let regular ISEL handle FP16
2859-
if (DestVT == MVT::f16)
2859+
if (DestVT == MVT::f16 || DestVT == MVT::bf16)
28602860
return false;
28612861

28622862
assert((DestVT == MVT::f32 || DestVT == MVT::f64) &&
@@ -2978,7 +2978,7 @@ bool AArch64FastISel::fastLowerArguments() {
29782978
} else if (VT == MVT::i64) {
29792979
SrcReg = Registers[1][GPRIdx++];
29802980
RC = &AArch64::GPR64RegClass;
2981-
} else if (VT == MVT::f16) {
2981+
} else if (VT == MVT::f16 || VT == MVT::bf16) {
29822982
SrcReg = Registers[2][FPRIdx++];
29832983
RC = &AArch64::FPR16RegClass;
29842984
} else if (VT == MVT::f32) {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 112 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4121,14 +4121,16 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
41214121

41224122
// Now that we have rounded, shift the bits into position.
41234123
Narrow = DAG.getNode(ISD::SRL, dl, I32, Narrow,
4124-
DAG.getShiftAmountConstant(16, I32, dl));
4124+
DAG.getShiftAmountConstant(16, I32, dl));
41254125
if (VT.isVector()) {
41264126
EVT I16 = I32.changeVectorElementType(MVT::i16);
41274127
Narrow = DAG.getNode(ISD::TRUNCATE, dl, I16, Narrow);
41284128
return DAG.getNode(ISD::BITCAST, dl, VT, Narrow);
41294129
}
41304130
Narrow = DAG.getNode(ISD::BITCAST, dl, F32, Narrow);
4131-
return DAG.getTargetExtractSubreg(AArch64::hsub, dl, VT, Narrow);
4131+
SDValue Result = DAG.getTargetExtractSubreg(AArch64::hsub, dl, VT, Narrow);
4132+
return IsStrict ? DAG.getMergeValues({Result, Op.getOperand(0)}, dl)
4133+
: Result;
41324134
}
41334135

41344136
if (SrcVT != MVT::f128) {
@@ -4487,20 +4489,121 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
44874489
bool IsStrict = Op->isStrictFPOpcode();
44884490
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
44894491

4490-
// f16 conversions are promoted to f32 when full fp16 is not supported.
4491-
if ((Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) || Op.getValueType() == MVT::bf16) {
4492+
bool IsSigned = Op->getOpcode() == ISD::STRICT_SINT_TO_FP ||
4493+
Op->getOpcode() == ISD::SINT_TO_FP;
4494+
4495+
auto IntToFpViaPromotion = [&](EVT PromoteVT) {
44924496
SDLoc dl(Op);
44934497
if (IsStrict) {
4494-
SDValue Val = DAG.getNode(Op.getOpcode(), dl, {MVT::f32, MVT::Other},
4498+
SDValue Val = DAG.getNode(Op.getOpcode(), dl, {PromoteVT, MVT::Other},
44954499
{Op.getOperand(0), SrcVal});
44964500
return DAG.getNode(
44974501
ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
44984502
{Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
44994503
}
4500-
return DAG.getNode(
4501-
ISD::FP_ROUND, dl, Op.getValueType(),
4502-
DAG.getNode(Op.getOpcode(), dl, MVT::f32, SrcVal),
4503-
DAG.getIntPtrConstant(0, dl));
4504+
return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
4505+
DAG.getNode(Op.getOpcode(), dl, PromoteVT, SrcVal),
4506+
DAG.getIntPtrConstant(0, dl));
4507+
};
4508+
4509+
if (Op.getValueType() == MVT::bf16) {
4510+
// bf16 conversions are promoted to f32 when converting from i16.
4511+
if (DAG.ComputeMaxSignificantBits(SrcVal) <= 24) {
4512+
return IntToFpViaPromotion(MVT::f32);
4513+
}
4514+
4515+
// bf16 conversions are promoted to f64 when converting from i32.
4516+
if (DAG.ComputeMaxSignificantBits(SrcVal) <= 53) {
4517+
return IntToFpViaPromotion(MVT::f64);
4518+
}
4519+
4520+
// We need to be careful about i64 -> bf16.
4521+
// Consider an i32 22216703.
4522+
// This number cannot be represented exactly as an f32 and so a itofp will
4523+
// turn it into 22216704.0 fptrunc to bf16 will turn this into 22282240.0
4524+
// However, the correct bf16 was supposed to be 22151168.0
4525+
// We need to use sticky rounding to get this correct.
4526+
if (SrcVal.getValueType() == MVT::i64) {
4527+
SDLoc DL(Op);
4528+
// This algorithm is equivalent to the following:
4529+
// uint64_t SrcHi = SrcVal & ~0xfffull;
4530+
// uint64_t SrcLo = SrcVal & 0xfffull;
4531+
// uint64_t Highest = SrcVal >> 53;
4532+
// bool HasHighest = Highest != 0;
4533+
// uint64_t ToRound = HasHighest ? SrcHi : SrcVal;
4534+
// double Rounded = static_cast<double>(ToRound);
4535+
// uint64_t RoundedBits = std::bit_cast<uint64_t>(Rounded);
4536+
// uint64_t HasLo = SrcLo != 0;
4537+
// bool NeedsAdjustment = HasHighest & HasLo;
4538+
// uint64_t AdjustedBits = RoundedBits | uint64_t{NeedsAdjustment};
4539+
// double Adjusted = std::bit_cast<double>(AdjustedBits);
4540+
// return static_cast<__bf16>(Adjusted);
4541+
//
4542+
// Essentially, what happens is that SrcVal either fits perfectly in a
4543+
// double-precision value or it is too big. If it is sufficiently small,
4544+
// we should just go u64 -> double -> bf16 in a naive way. Otherwise, we
4545+
// ensure that u64 -> double has no rounding error by only using the 52
4546+
// MSB of the input. The low order bits will get merged into a sticky bit
4547+
// which will avoid issues incurred by double rounding.
4548+
4549+
// Signed conversion is more or less like so:
4550+
// copysign((__bf16)abs(SrcVal), SrcVal)
4551+
SDValue SignBit;
4552+
if (IsSigned) {
4553+
SignBit = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4554+
DAG.getConstant(1ull << 63, DL, MVT::i64));
4555+
SrcVal = DAG.getNode(ISD::ABS, DL, MVT::i64, SrcVal);
4556+
}
4557+
SDValue SrcHi = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4558+
DAG.getConstant(~0xfffull, DL, MVT::i64));
4559+
SDValue SrcLo = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4560+
DAG.getConstant(0xfffull, DL, MVT::i64));
4561+
SDValue Highest =
4562+
DAG.getNode(ISD::SRL, DL, MVT::i64, SrcVal,
4563+
DAG.getShiftAmountConstant(53, MVT::i64, DL));
4564+
SDValue Zero64 = DAG.getConstant(0, DL, MVT::i64);
4565+
SDValue ToRound =
4566+
DAG.getSelectCC(DL, Highest, Zero64, SrcHi, SrcVal, ISD::SETNE);
4567+
SDValue Rounded =
4568+
IsStrict ? DAG.getNode(Op.getOpcode(), DL, {MVT::f64, MVT::Other},
4569+
{Op.getOperand(0), ToRound})
4570+
: DAG.getNode(Op.getOpcode(), DL, MVT::f64, ToRound);
4571+
4572+
SDValue RoundedBits = DAG.getNode(ISD::BITCAST, DL, MVT::i64, Rounded);
4573+
if (SignBit) {
4574+
RoundedBits = DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, SignBit);
4575+
}
4576+
4577+
SDValue HasHighest = DAG.getSetCC(
4578+
DL,
4579+
getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i64),
4580+
Highest, Zero64, ISD::SETNE);
4581+
4582+
SDValue HasLo = DAG.getSetCC(
4583+
DL,
4584+
getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i64),
4585+
SrcLo, Zero64, ISD::SETNE);
4586+
4587+
SDValue NeedsAdjustment =
4588+
DAG.getNode(ISD::AND, DL, HasLo.getValueType(), HasHighest, HasLo);
4589+
NeedsAdjustment = DAG.getZExtOrTrunc(NeedsAdjustment, DL, MVT::i64);
4590+
4591+
SDValue AdjustedBits =
4592+
DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, NeedsAdjustment);
4593+
SDValue Adjusted = DAG.getNode(ISD::BITCAST, DL, MVT::f64, AdjustedBits);
4594+
return IsStrict
4595+
? DAG.getNode(ISD::STRICT_FP_ROUND, DL,
4596+
{Op.getValueType(), MVT::Other},
4597+
{Rounded.getValue(1), Adjusted,
4598+
DAG.getIntPtrConstant(0, DL)})
4599+
: DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), Adjusted,
4600+
DAG.getIntPtrConstant(0, DL, true));
4601+
}
4602+
}
4603+
4604+
// f16 conversions are promoted to f32 when full fp16 is not supported.
4605+
if (Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
4606+
return IntToFpViaPromotion(MVT::f32);
45044607
}
45054608

45064609
// i128 conversions are libcalls.

llvm/test/CodeGen/AArch64/arm64-convert-v4f64.ll

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,30 @@ define <4 x half> @uitofp_v4i64_to_v4f16(ptr %ptr) {
5454
ret <4 x half> %tmp2
5555
}
5656

57+
define <4 x bfloat> @uitofp_v4i64_to_v4bf16(ptr %ptr) {
58+
; CHECK-LABEL: uitofp_v4i64_to_v4bf16:
59+
; CHECK: // %bb.0:
60+
; CHECK-NEXT: ldp q0, q1, [x0]
61+
; CHECK-NEXT: movi v2.4s, #1
62+
; CHECK-NEXT: ucvtf v0.2d, v0.2d
63+
; CHECK-NEXT: ucvtf v1.2d, v1.2d
64+
; CHECK-NEXT: fcvtn v0.2s, v0.2d
65+
; CHECK-NEXT: fcvtn2 v0.4s, v1.2d
66+
; CHECK-NEXT: movi v1.4s, #127, msl #8
67+
; CHECK-NEXT: ushr v3.4s, v0.4s, #16
68+
; CHECK-NEXT: add v1.4s, v0.4s, v1.4s
69+
; CHECK-NEXT: and v2.16b, v3.16b, v2.16b
70+
; CHECK-NEXT: add v1.4s, v2.4s, v1.4s
71+
; CHECK-NEXT: fcmeq v2.4s, v0.4s, v0.4s
72+
; CHECK-NEXT: orr v0.4s, #64, lsl #16
73+
; CHECK-NEXT: bit v0.16b, v1.16b, v2.16b
74+
; CHECK-NEXT: shrn v0.4h, v0.4s, #16
75+
; CHECK-NEXT: ret
76+
%tmp1 = load <4 x i64>, ptr %ptr
77+
%tmp2 = uitofp <4 x i64> %tmp1 to <4 x bfloat>
78+
ret <4 x bfloat> %tmp2
79+
}
80+
5781
define <4 x i16> @trunc_v4i64_to_v4i16(ptr %ptr) {
5882
; CHECK-LABEL: trunc_v4i64_to_v4i16:
5983
; CHECK: // %bb.0:

0 commit comments

Comments
 (0)