Commit d3e71ba (parent 494e9fa)

[AMDGPU] Adapt new lowering sequence for fdiv16

The current lowering of fdiv16 can generate incorrectly rounded results in some cases. Fixes SWDEV-47760.

File tree

8 files changed: +3507 −1032 lines

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp (30 additions, 6 deletions)

@@ -4903,16 +4903,40 @@ bool AMDGPULegalizerInfo::legalizeFDIV16(MachineInstr &MI,
   LLT S16 = LLT::scalar(16);
   LLT S32 = LLT::scalar(32);
 
+  // a32.u = opx(V_CVT_F32_F16, a.u);                      // CVT to F32
+  // b32.u = opx(V_CVT_F32_F16, b.u);                      // CVT to F32
+  // r32.u = opx(V_RCP_F32, b32.u);                        // rcp = 1 / d
+  // q32.u = opx(V_MUL_F32, a32.u, r32.u);                 // q = n * rcp
+  // e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u); // err = -d * q + n
+  // q32.u = opx(V_MAD_F32, e32.u, r32.u, q32.u);          // q = n * rcp
+  // e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u); // err = -d * q + n
+  // tmp.u = opx(V_MUL_F32, e32.u, r32.u);
+  // tmp.u = opx(V_AND_B32, tmp.u, 0xff800000)
+  // q32.u = opx(V_ADD_F32, tmp.u, q32.u);
+  // q16.u = opx(V_CVT_F16_F32, q32.u);
+  // q16.u = opx(V_DIV_FIXUP_F16, q16.u, b.u, a.u);        // q = touchup(q, d, n)
+
   auto LHSExt = B.buildFPExt(S32, LHS, Flags);
   auto RHSExt = B.buildFPExt(S32, RHS, Flags);
-
-  auto RCP = B.buildIntrinsic(Intrinsic::amdgcn_rcp, {S32})
+  auto NegRHSExt = B.buildFNeg(S32, RHSExt);
+  auto Rcp = B.buildIntrinsic(Intrinsic::amdgcn_rcp, {S32})
                  .addUse(RHSExt.getReg(0))
                  .setMIFlags(Flags);
-
-  auto QUOT = B.buildFMul(S32, LHSExt, RCP, Flags);
-  auto RDst = B.buildFPTrunc(S16, QUOT, Flags);
-
+  auto Quot = B.buildFMul(S32, LHSExt, Rcp, Flags);
+  MachineInstrBuilder Err;
+  if (ST.hasMadMacF32Insts()) {
+    Err = B.buildFMAD(S32, NegRHSExt, Quot, LHSExt, Flags);
+    Quot = B.buildFMAD(S32, Err, Rcp, Quot, Flags);
+    Err = B.buildFMAD(S32, NegRHSExt, Quot, LHSExt, Flags);
+  } else {
+    Err = B.buildFMA(S32, NegRHSExt, Quot, LHSExt, Flags);
+    Quot = B.buildFMA(S32, Err, Rcp, Quot, Flags);
+    Err = B.buildFMA(S32, NegRHSExt, Quot, LHSExt, Flags);
+  }
+  auto Tmp = B.buildFMul(S32, Err, Rcp, Flags);
+  Tmp = B.buildAnd(S32, Tmp, B.buildConstant(S32, 0xff800000));
+  Quot = B.buildFAdd(S32, Tmp, Quot, Flags);
+  auto RDst = B.buildFPTrunc(S16, Quot, Flags);
   B.buildIntrinsic(Intrinsic::amdgcn_div_fixup, Res)
       .addUse(RDst.getReg(0))
       .addUse(RHS)
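
For readers following the sequence: this is a Newton-Raphson style refinement in f32 (err = n - d * q, folded back into q twice via mad/fma), followed by quantizing the last correction term and narrowing to f16. Below is a minimal scalar sketch of the same steps in plain C++, under stated assumptions: it uses the Clang/GCC _Float16 extension for the half type, an exact 1.0f / b in place of the hardware's approximate V_RCP_F32, and std::fmaf standing in for both V_MAD_F32 and V_FMA_F32 (the in-tree code picks buildFMAD when ST.hasMadMacF32Insts(), buildFMA otherwise). It omits the final V_DIV_FIXUP_F16, so NaN, infinity, and zero-divisor cases are not handled. The file and function names are ours, not part of the commit.

// fdiv16_sketch.cpp -- illustrative scalar emulation of the new lowering.
// Assumes a compiler supporting the _Float16 extension (recent Clang/GCC).
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// The V_AND_B32 step: keep only the sign and exponent bits of the scaled
// error term, snapping it to plus-or-minus a power of two (or zero).
static float maskSignExp(float x) {
  uint32_t u;
  std::memcpy(&u, &x, sizeof u);
  u &= 0xff800000u;
  std::memcpy(&x, &u, sizeof x);
  return x;
}

static _Float16 fdiv16_emulated(_Float16 a, _Float16 b) {
  float a32 = (float)a;                  // V_CVT_F32_F16
  float b32 = (float)b;                  // V_CVT_F32_F16
  float r32 = 1.0f / b32;                // V_RCP_F32 (exact here, approx on HW)
  float q32 = a32 * r32;                 // q = n * rcp
  float e32 = std::fmaf(-b32, q32, a32); // err = -d * q + n
  q32 = std::fmaf(e32, r32, q32);        // first refinement of q
  e32 = std::fmaf(-b32, q32, a32);       // recompute err against refined q
  float tmp = maskSignExp(e32 * r32);    // quantized final correction
  q32 = tmp + q32;                       // V_ADD_F32
  return (_Float16)q32;                  // V_CVT_F16_F32; no DIV_FIXUP step
}

int main() {
  std::printf("1/3 as f16: %a\n",
              (double)fdiv16_emulated((_Float16)1.0f, (_Float16)3.0f));
  return 0;
}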

llvm/lib/Target/AMDGPU/SIISelLowering.cpp (41 additions, 12 deletions)

@@ -10619,19 +10619,48 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
     return FastLowered;
 
   SDLoc SL(Op);
-  SDValue Src0 = Op.getOperand(0);
-  SDValue Src1 = Op.getOperand(1);
-
-  SDValue CvtSrc0 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Src0);
-  SDValue CvtSrc1 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Src1);
-
-  SDValue RcpSrc1 = DAG.getNode(AMDGPUISD::RCP, SL, MVT::f32, CvtSrc1);
-  SDValue Quot = DAG.getNode(ISD::FMUL, SL, MVT::f32, CvtSrc0, RcpSrc1);
-
-  SDValue FPRoundFlag = DAG.getTargetConstant(0, SL, MVT::i32);
-  SDValue BestQuot = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot, FPRoundFlag);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
 
-  return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, BestQuot, Src1, Src0);
+  // a32.u = opx(V_CVT_F32_F16, a.u);                      // CVT to F32
+  // b32.u = opx(V_CVT_F32_F16, b.u);                      // CVT to F32
+  // r32.u = opx(V_RCP_F32, b32.u);                        // rcp = 1 / d
+  // q32.u = opx(V_MUL_F32, a32.u, r32.u);                 // q = n * rcp
+  // e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u); // err = -d * q + n
+  // q32.u = opx(V_MAD_F32, e32.u, r32.u, q32.u);          // q = n * rcp
+  // e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u); // err = -d * q + n
+  // tmp.u = opx(V_MUL_F32, e32.u, r32.u);
+  // tmp.u = opx(V_AND_B32, tmp.u, 0xff800000)
+  // q32.u = opx(V_ADD_F32, tmp.u, q32.u);
+  // q16.u = opx(V_CVT_F16_F32, q32.u);
+  // q16.u = opx(V_DIV_FIXUP_F16, q16.u, b.u, a.u);        // q = touchup(q, d, n)
 
+  // We will use ISD::FMA on targets that don't support ISD::FMAD.
+  unsigned FMADOpCode =
+      isOperationLegal(ISD::FMAD, MVT::f32) ? ISD::FMAD : ISD::FMA;
+
+  SDValue LHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, LHS);
+  SDValue RHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, RHS);
+  SDValue NegRHSExt = DAG.getNode(ISD::FNEG, SL, MVT::f32, RHSExt);
+  SDValue Rcp =
+      DAG.getNode(AMDGPUISD::RCP, SL, MVT::f32, RHSExt, Op->getFlags());
+  SDValue Quot =
+      DAG.getNode(ISD::FMUL, SL, MVT::f32, LHSExt, Rcp, Op->getFlags());
+  SDValue Err = DAG.getNode(FMADOpCode, SL, MVT::f32, NegRHSExt, Quot, LHSExt,
+                            Op->getFlags());
+  Quot = DAG.getNode(FMADOpCode, SL, MVT::f32, Err, Rcp, Quot, Op->getFlags());
+  Err = DAG.getNode(FMADOpCode, SL, MVT::f32, NegRHSExt, Quot, LHSExt,
+                    Op->getFlags());
+  SDValue Tmp = DAG.getNode(ISD::FMUL, SL, MVT::f32, Err, Rcp, Op->getFlags());
+  SDValue TmpCast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, Tmp);
+  TmpCast = DAG.getNode(ISD::AND, SL, MVT::i32, TmpCast,
+                        DAG.getConstant(0xff800000, SL, MVT::i32));
+  Tmp = DAG.getNode(ISD::BITCAST, SL, MVT::f32, TmpCast);
+  Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot, Op->getFlags());
+  SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot,
+                             DAG.getConstant(0, SL, MVT::i32));
+  return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, RDst, RHS, LHS,
+                     Op->getFlags());
 }
 
 // Faster 2.5 ULP division that does not support denormals.
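
The one step that may not be self-explanatory in either version is the AND with 0xff800000. In IEEE-754 binary32, that mask covers the sign bit and the eight exponent bits, so it snaps the scaled error term e32 * r32 to a power of two before it is added back into the quotient. Our reading, an inference rather than anything stated in the commit, is that a power-of-two correction adds into the quotient with no extra rounding of its own, so the only rounding left before V_DIV_FIXUP_F16 happens in the final convert to f16; the old bare rcp-multiply-truncate sequence had no such guarantee. A tiny self-contained demonstration of the masking (file name and sample value are ours):

// mask_demo.cpp -- what the AND with 0xff800000 does to an f32 value.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float e = 0.0123f; // stand-in for the scaled error term e32 * r32
  uint32_t u;
  std::memcpy(&u, &e, sizeof u);
  u &= 0xff800000u; // sign (bit 31) + exponent (bits 30..23); mantissa cleared
  float snapped;
  std::memcpy(&snapped, &u, sizeof snapped);
  // Prints the original value and the power of two it snapped to,
  // roughly "0x1.93p-7 -> 0x1p-7".
  std::printf("%a -> %a\n", (double)e, (double)snapped);
  return 0;
}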
