@@ -4121,14 +4121,16 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
4121
4121
4122
4122
// Now that we have rounded, shift the bits into position.
4123
4123
Narrow = DAG.getNode(ISD::SRL, dl, I32, Narrow,
4124
- DAG.getShiftAmountConstant(16, I32, dl));
4124
+ DAG.getShiftAmountConstant(16, I32, dl));
4125
4125
if (VT.isVector()) {
4126
4126
EVT I16 = I32.changeVectorElementType(MVT::i16);
4127
4127
Narrow = DAG.getNode(ISD::TRUNCATE, dl, I16, Narrow);
4128
4128
return DAG.getNode(ISD::BITCAST, dl, VT, Narrow);
4129
4129
}
4130
4130
Narrow = DAG.getNode(ISD::BITCAST, dl, F32, Narrow);
4131
- return DAG.getTargetExtractSubreg(AArch64::hsub, dl, VT, Narrow);
4131
+ SDValue Result = DAG.getTargetExtractSubreg(AArch64::hsub, dl, VT, Narrow);
4132
+ return IsStrict ? DAG.getMergeValues({Result, Op.getOperand(0)}, dl)
4133
+ : Result;
4132
4134
}
4133
4135
4134
4136
if (SrcVT != MVT::f128) {
@@ -4487,20 +4489,121 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
4487
4489
bool IsStrict = Op->isStrictFPOpcode();
4488
4490
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
4489
4491
4490
- // f16 conversions are promoted to f32 when full fp16 is not supported.
4491
- if ((Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) || Op.getValueType() == MVT::bf16) {
4492
+ bool IsSigned = Op->getOpcode() == ISD::STRICT_SINT_TO_FP ||
4493
+ Op->getOpcode() == ISD::SINT_TO_FP;
4494
+
4495
+ auto IntToFpViaPromotion = [&](EVT PromoteVT) {
4492
4496
SDLoc dl(Op);
4493
4497
if (IsStrict) {
4494
- SDValue Val = DAG.getNode(Op.getOpcode(), dl, {MVT::f32 , MVT::Other},
4498
+ SDValue Val = DAG.getNode(Op.getOpcode(), dl, {PromoteVT , MVT::Other},
4495
4499
{Op.getOperand(0), SrcVal});
4496
4500
return DAG.getNode(
4497
4501
ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
4498
4502
{Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
4499
4503
}
4500
- return DAG.getNode(
4501
- ISD::FP_ROUND, dl, Op.getValueType(),
4502
- DAG.getNode(Op.getOpcode(), dl, MVT::f32, SrcVal),
4503
- DAG.getIntPtrConstant(0, dl));
4504
+ return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
4505
+ DAG.getNode(Op.getOpcode(), dl, PromoteVT, SrcVal),
4506
+ DAG.getIntPtrConstant(0, dl));
4507
+ };
4508
+
4509
+ if (Op.getValueType() == MVT::bf16) {
4510
+ // bf16 conversions are promoted to f32 when converting from i16.
4511
+ if (DAG.ComputeMaxSignificantBits(SrcVal) <= 24) {
4512
+ return IntToFpViaPromotion(MVT::f32);
4513
+ }
4514
+
4515
+ // bf16 conversions are promoted to f64 when converting from i32.
4516
+ if (DAG.ComputeMaxSignificantBits(SrcVal) <= 53) {
4517
+ return IntToFpViaPromotion(MVT::f64);
4518
+ }
4519
+
4520
+ // We need to be careful about i64 -> bf16.
4521
+ // Consider an i32 22216703.
4522
+ // This number cannot be represented exactly as an f32 and so a itofp will
4523
+ // turn it into 22216704.0 fptrunc to bf16 will turn this into 22282240.0
4524
+ // However, the correct bf16 was supposed to be 22151168.0
4525
+ // We need to use sticky rounding to get this correct.
4526
+ if (SrcVal.getValueType() == MVT::i64) {
4527
+ SDLoc DL(Op);
4528
+ // This algorithm is equivalent to the following:
4529
+ // uint64_t SrcHi = SrcVal & ~0xfffull;
4530
+ // uint64_t SrcLo = SrcVal & 0xfffull;
4531
+ // uint64_t Highest = SrcVal >> 53;
4532
+ // bool HasHighest = Highest != 0;
4533
+ // uint64_t ToRound = HasHighest ? SrcHi : SrcVal;
4534
+ // double Rounded = static_cast<double>(ToRound);
4535
+ // uint64_t RoundedBits = std::bit_cast<uint64_t>(Rounded);
4536
+ // uint64_t HasLo = SrcLo != 0;
4537
+ // bool NeedsAdjustment = HasHighest & HasLo;
4538
+ // uint64_t AdjustedBits = RoundedBits | uint64_t{NeedsAdjustment};
4539
+ // double Adjusted = std::bit_cast<double>(AdjustedBits);
4540
+ // return static_cast<__bf16>(Adjusted);
4541
+ //
4542
+ // Essentially, what happens is that SrcVal either fits perfectly in a
4543
+ // double-precision value or it is too big. If it is sufficiently small,
4544
+ // we should just go u64 -> double -> bf16 in a naive way. Otherwise, we
4545
+ // ensure that u64 -> double has no rounding error by only using the 52
4546
+ // MSB of the input. The low order bits will get merged into a sticky bit
4547
+ // which will avoid issues incurred by double rounding.
4548
+
4549
+ // Signed conversion is more or less like so:
4550
+ // copysign((__bf16)abs(SrcVal), SrcVal)
4551
+ SDValue SignBit;
4552
+ if (IsSigned) {
4553
+ SignBit = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4554
+ DAG.getConstant(1ull << 63, DL, MVT::i64));
4555
+ SrcVal = DAG.getNode(ISD::ABS, DL, MVT::i64, SrcVal);
4556
+ }
4557
+ SDValue SrcHi = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4558
+ DAG.getConstant(~0xfffull, DL, MVT::i64));
4559
+ SDValue SrcLo = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4560
+ DAG.getConstant(0xfffull, DL, MVT::i64));
4561
+ SDValue Highest =
4562
+ DAG.getNode(ISD::SRL, DL, MVT::i64, SrcVal,
4563
+ DAG.getShiftAmountConstant(53, MVT::i64, DL));
4564
+ SDValue Zero64 = DAG.getConstant(0, DL, MVT::i64);
4565
+ SDValue ToRound =
4566
+ DAG.getSelectCC(DL, Highest, Zero64, SrcHi, SrcVal, ISD::SETNE);
4567
+ SDValue Rounded =
4568
+ IsStrict ? DAG.getNode(Op.getOpcode(), DL, {MVT::f64, MVT::Other},
4569
+ {Op.getOperand(0), ToRound})
4570
+ : DAG.getNode(Op.getOpcode(), DL, MVT::f64, ToRound);
4571
+
4572
+ SDValue RoundedBits = DAG.getNode(ISD::BITCAST, DL, MVT::i64, Rounded);
4573
+ if (SignBit) {
4574
+ RoundedBits = DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, SignBit);
4575
+ }
4576
+
4577
+ SDValue HasHighest = DAG.getSetCC(
4578
+ DL,
4579
+ getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i64),
4580
+ Highest, Zero64, ISD::SETNE);
4581
+
4582
+ SDValue HasLo = DAG.getSetCC(
4583
+ DL,
4584
+ getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i64),
4585
+ SrcLo, Zero64, ISD::SETNE);
4586
+
4587
+ SDValue NeedsAdjustment =
4588
+ DAG.getNode(ISD::AND, DL, HasLo.getValueType(), HasHighest, HasLo);
4589
+ NeedsAdjustment = DAG.getZExtOrTrunc(NeedsAdjustment, DL, MVT::i64);
4590
+
4591
+ SDValue AdjustedBits =
4592
+ DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, NeedsAdjustment);
4593
+ SDValue Adjusted = DAG.getNode(ISD::BITCAST, DL, MVT::f64, AdjustedBits);
4594
+ return IsStrict
4595
+ ? DAG.getNode(ISD::STRICT_FP_ROUND, DL,
4596
+ {Op.getValueType(), MVT::Other},
4597
+ {Rounded.getValue(1), Adjusted,
4598
+ DAG.getIntPtrConstant(0, DL)})
4599
+ : DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), Adjusted,
4600
+ DAG.getIntPtrConstant(0, DL, true));
4601
+ }
4602
+ }
4603
+
4604
+ // f16 conversions are promoted to f32 when full fp16 is not supported.
4605
+ if (Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
4606
+ return IntToFpViaPromotion(MVT::f32);
4504
4607
}
4505
4608
4506
4609
// i128 conversions are libcalls.
0 commit comments