Skip to content

Commit ffc3e80

Browse files
author
QingShan Zhang
committed
[NFC] [DAGCombine] Correct the result for sqrt even if the iteration count is zero
Currently, we only correct the result for sqrt if the iteration count is greater than 0. This doesn't make sense, as the correction and the iteration count are not strictly related. Reviewed By: dmgreen, spatel, RKSimon Differential Revision: https://reviews.llvm.org/D94480
1 parent 89a5147 commit ffc3e80

File tree

6 files changed

+57
-48
lines changed

6 files changed

+57
-48
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4287,9 +4287,7 @@ class TargetLowering : public TargetLoweringBase {
42874287
/// comparison may check if the operand is NAN, INF, zero, normal, etc. The
42884288
/// result should be used as the condition operand for a select or branch.
42894289
virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
4290-
const DenormalMode &Mode) const {
4291-
return SDValue();
4292-
}
4290+
const DenormalMode &Mode) const;
42934291

42944292
/// Return a target-dependent result if the input operand is not suitable for
42954293
/// use with a square root estimate calculation.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22275,43 +22275,21 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
2227522275
Reciprocal)) {
2227622276
AddToWorklist(Est.getNode());
2227722277

22278-
if (Iterations) {
22278+
if (Iterations)
2227922279
Est = UseOneConstNR
2228022280
? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
2228122281
: buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
22282-
22283-
if (!Reciprocal) {
22284-
SDLoc DL(Op);
22285-
EVT CCVT = getSetCCResultType(VT);
22286-
SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
22287-
DenormalMode DenormMode = DAG.getDenormalMode(VT);
22288-
// Try the target specific test first.
22289-
SDValue Test = TLI.getSqrtInputTest(Op, DAG, DenormMode);
22290-
if (!Test) {
22291-
// If no test provided by target, testing it with denormal inputs to
22292-
// avoid wrong estimate.
22293-
if (DenormMode.Input == DenormalMode::IEEE) {
22294-
// This is specifically a check for the handling of denormal inputs,
22295-
// not the result.
22296-
22297-
// Test = fabs(X) < SmallestNormal
22298-
const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
22299-
APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
22300-
SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
22301-
SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
22302-
Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
22303-
} else
22304-
// Test = X == 0.0
22305-
Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
22306-
}
22307-
22308-
// The estimate is now completely wrong if the input was exactly 0.0 or
22309-
// possibly a denormal. Force the answer to 0.0 or value provided by
22310-
// target for those cases.
22311-
Est = DAG.getNode(
22312-
Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
22313-
Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
22314-
}
22282+
if (!Reciprocal) {
22283+
SDLoc DL(Op);
22284+
// Try the target specific test first.
22285+
SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
22286+
22287+
// The estimate is now completely wrong if the input was exactly 0.0 or
22288+
// possibly a denormal. Force the answer to 0.0 or value provided by
22289+
// target for those cases.
22290+
Est = DAG.getNode(
22291+
Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
22292+
Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
2231522293
}
2231622294
return Est;
2231722295
}

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5841,6 +5841,28 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
58415841
return false;
58425842
}
58435843

5844+
SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
5845+
const DenormalMode &Mode) const {
5846+
SDLoc DL(Op);
5847+
EVT VT = Op.getValueType();
5848+
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
5849+
SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
5850+
// Testing it with denormal inputs to avoid wrong estimate.
5851+
if (Mode.Input == DenormalMode::IEEE) {
5852+
// This is specifically a check for the handling of denormal inputs,
5853+
// not the result.
5854+
5855+
// Test = fabs(X) < SmallestNormal
5856+
const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
5857+
APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
5858+
SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
5859+
SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
5860+
return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
5861+
}
5862+
// Test = X == 0.0
5863+
return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
5864+
}
5865+
58445866
SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
58455867
bool LegalOps, bool OptForSize,
58465868
NegatibleCost &Cost,

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7471,6 +7471,22 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
74717471
return SDValue();
74727472
}
74737473

7474+
SDValue
7475+
AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
7476+
const DenormalMode &Mode) const {
7477+
SDLoc DL(Op);
7478+
EVT VT = Op.getValueType();
7479+
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
7480+
SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
7481+
return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
7482+
}
7483+
7484+
SDValue
7485+
AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op,
7486+
SelectionDAG &DAG) const {
7487+
return Op;
7488+
}
7489+
74747490
SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
74757491
SelectionDAG &DAG, int Enabled,
74767492
int &ExtraSteps,
@@ -7494,17 +7510,8 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
74947510
Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags);
74957511
Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
74967512
}
7497-
if (!Reciprocal) {
7498-
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
7499-
VT);
7500-
SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
7501-
SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ);
7502-
7513+
if (!Reciprocal)
75037514
Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags);
7504-
// Correct the result if the operand is 0.0.
7505-
Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL,
7506-
VT, Eq, Operand, Estimate);
7507-
}
75087515

75097516
ExtraSteps = 0;
75107517
return Estimate;

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,10 @@ class AArch64TargetLowering : public TargetLowering {
961961
bool Reciprocal) const override;
962962
SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
963963
int &ExtraSteps) const override;
964+
SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
965+
const DenormalMode &Mode) const override;
966+
SDValue getSqrtResultForDenormInput(SDValue Operand,
967+
SelectionDAG &DAG) const override;
964968
unsigned combineRepeatedFPDivisors() const override;
965969

966970
ConstraintType getConstraintType(StringRef Constraint) const override;

llvm/lib/Target/PowerPC/PPCISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12133,7 +12133,7 @@ SDValue PPCTargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
1213312133
if (!isTypeLegal(MVT::i1) ||
1213412134
(VT != MVT::f64 &&
1213512135
((VT != MVT::v2f64 && VT != MVT::v4f32) || !Subtarget.hasVSX())))
12136-
return SDValue();
12136+
return TargetLowering::getSqrtInputTest(Op, DAG, Mode);
1213712137

1213812138
SDLoc DL(Op);
1213912139
// The output register of FTSQRT is CR field.

0 commit comments

Comments
 (0)