Skip to content

Commit 03ada68

Browse files
committed
[AMDGPU] Adapt new lowering sequence for fdiv16
The current lowering of fdiv16 can generate an incorrectly rounded result in some cases. Fixes SWDEV-47760.
1 parent eaedbbc commit 03ada68

File tree

8 files changed

+2088
-901
lines changed

8 files changed

+2088
-901
lines changed

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4900,16 +4900,33 @@ bool AMDGPULegalizerInfo::legalizeFDIV16(MachineInstr &MI,
49004900
LLT S16 = LLT::scalar(16);
49014901
LLT S32 = LLT::scalar(32);
49024902

4903+
// a32.u = opx(V_CVT_F32_F16, a.u);
4904+
// b32.u = opx(V_CVT_F32_F16, b.u);
4905+
// r32.u = opx(V_RCP_F32, b32.u);
4906+
// q32.u = opx(V_MUL_F32, a32.u, r32.u);
4907+
// e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u);
4908+
// q32.u = opx(V_MAD_F32, e32.u, r32.u, q32.u);
4909+
// e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u);
4910+
// tmp.u = opx(V_MUL_F32, e32.u, r32.u);
4911+
// tmp.u = opx(V_AND_B32, tmp.u, 0xff800000);
4912+
// q32.u = opx(V_ADD_F32, tmp.u, q32.u);
4913+
// q16.u = opx(V_CVT_F16_F32, q32.u);
4914+
// q16.u = opx(V_DIV_FIXUP_F16, q16.u, b.u, a.u);
4915+
49034916
auto LHSExt = B.buildFPExt(S32, LHS, Flags);
49044917
auto RHSExt = B.buildFPExt(S32, RHS, Flags);
4905-
4906-
auto RCP = B.buildIntrinsic(Intrinsic::amdgcn_rcp, {S32})
4918+
auto NegRHSExt = B.buildFNeg(S32, RHSExt);
4919+
auto Rcp = B.buildIntrinsic(Intrinsic::amdgcn_rcp, {S32})
49074920
.addUse(RHSExt.getReg(0))
49084921
.setMIFlags(Flags);
4909-
4910-
auto QUOT = B.buildFMul(S32, LHSExt, RCP, Flags);
4911-
auto RDst = B.buildFPTrunc(S16, QUOT, Flags);
4912-
4922+
auto Quot = B.buildFMul(S32, LHSExt, Rcp);
4923+
auto Err = B.buildFMA(S32, NegRHSExt, Quot, LHSExt);
4924+
Quot = B.buildFMA(S32, Err, Rcp, Quot);
4925+
Err = B.buildFMA(S32, NegRHSExt, Quot, LHSExt);
4926+
auto Tmp = B.buildFMul(S32, Err, Rcp);
4927+
Tmp = B.buildAnd(S32, Tmp, B.buildConstant(S32, 0xff800000));
4928+
Quot = B.buildFAdd(S32, Tmp, Quot);
4929+
auto RDst = B.buildFPTrunc(S16, Quot, Flags);
49134930
B.buildIntrinsic(Intrinsic::amdgcn_div_fixup, Res)
49144931
.addUse(RDst.getReg(0))
49154932
.addUse(RHS)

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10606,19 +10606,40 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
1060610606
return FastLowered;
1060710607

1060810608
SDLoc SL(Op);
10609-
SDValue Src0 = Op.getOperand(0);
10610-
SDValue Src1 = Op.getOperand(1);
10611-
10612-
SDValue CvtSrc0 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Src0);
10613-
SDValue CvtSrc1 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Src1);
10614-
10615-
SDValue RcpSrc1 = DAG.getNode(AMDGPUISD::RCP, SL, MVT::f32, CvtSrc1);
10616-
SDValue Quot = DAG.getNode(ISD::FMUL, SL, MVT::f32, CvtSrc0, RcpSrc1);
10617-
10618-
SDValue FPRoundFlag = DAG.getTargetConstant(0, SL, MVT::i32);
10619-
SDValue BestQuot = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot, FPRoundFlag);
10609+
SDValue LHS = Op.getOperand(0);
10610+
SDValue RHS = Op.getOperand(1);
1062010611

10621-
return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, BestQuot, Src1, Src0);
10612+
// a32.u = opx(V_CVT_F32_F16, a.u);
10613+
// b32.u = opx(V_CVT_F32_F16, b.u);
10614+
// r32.u = opx(V_RCP_F32, b32.u);
10615+
// q32.u = opx(V_MUL_F32, a32.u, r32.u);
10616+
// e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u);
10617+
// q32.u = opx(V_MAD_F32, e32.u, r32.u, q32.u);
10618+
// e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u);
10619+
// tmp.u = opx(V_MUL_F32, e32.u, r32.u);
10620+
// tmp.u = opx(V_AND_B32, tmp.u, 0xff800000);
10621+
// tmp.u = opx(V_FREXP_MANT_F32, tmp.u);
10622+
// q32.u = opx(V_ADD_F32, tmp.u, q32.u);
10623+
// q16.u = opx(V_CVT_F16_F32, q32.u);
10624+
// q16.u = opx(V_DIV_FIXUP_F16, q16.u, b.u, a.u);
10625+
10626+
SDValue LHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, LHS);
10627+
SDValue RHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, RHS);
10628+
SDValue NegRHSExt = DAG.getNode(ISD::FNEG, SL, MVT::f32, RHSExt);
10629+
SDValue Rcp = DAG.getNode(AMDGPUISD::RCP, SL, MVT::f32, RHSExt);
10630+
SDValue Quot = DAG.getNode(ISD::FMUL, SL, MVT::f32, LHSExt, Rcp);
10631+
SDValue Err = DAG.getNode(ISD::FMA, SL, MVT::f32, NegRHSExt, Quot, LHSExt);
10632+
Quot = DAG.getNode(ISD::FMA, SL, MVT::f32, Err, Rcp, Quot);
10633+
Err = DAG.getNode(ISD::FMA, SL, MVT::f32, NegRHSExt, Quot, LHSExt);
10634+
SDValue Tmp = DAG.getNode(ISD::FMUL, SL, MVT::f32, Err, Rcp);
10635+
SDValue TmpCast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, Tmp);
10636+
TmpCast = DAG.getNode(ISD::AND, SL, MVT::i32, TmpCast,
10637+
DAG.getConstant(0xff800000, SL, MVT::i32));
10638+
Tmp = DAG.getNode(ISD::BITCAST, SL, MVT::f32, TmpCast);
10639+
Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot);
10640+
SDValue FPRoundFlag = DAG.getConstant(0, SL, MVT::i32);
10641+
SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot, FPRoundFlag);
10642+
return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, RDst, RHS, LHS);
1062210643
}
1062310644

1062410645
// Faster 2.5 ULP division that does not support denormals.

0 commit comments

Comments
 (0)