Skip to content

[SelectionDAG] Require last operand of (STRICT_)FP_ROUND to be a TargetConstant. #117639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5277,7 +5277,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1);
else
Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1,
DAG.getIntPtrConstant(0, dl));
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should stop using getIntPtrConstant and just make it a fixed i32 (then you could just use getTargetConstant everywhere)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be MVT::i1?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or that, yes. Really this should be an "exact" flag on fptrunc, but for some reason only SelectionDAG has this custom "fast math flag"

Results.push_back(Tmp1);
break;
}
Expand Down Expand Up @@ -5425,7 +5425,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp1 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp3, Tmp1, Tmp2});
Tmp1 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
{Tmp1.getValue(1), Tmp1, DAG.getIntPtrConstant(0, dl)});
{Tmp1.getValue(1), Tmp1,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
Results.push_back(Tmp1);
Results.push_back(Tmp1.getValue(1));
break;
Expand All @@ -5450,7 +5451,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp4 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp4, Tmp1, Tmp2, Tmp3});
Tmp4 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
{Tmp4.getValue(1), Tmp4, DAG.getIntPtrConstant(0, dl)});
{Tmp4.getValue(1), Tmp4,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
Results.push_back(Tmp4);
Results.push_back(Tmp4.getValue(1));
break;
Expand Down Expand Up @@ -5478,7 +5480,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp1.getValue(1), Tmp1, Node->getOperand(2)});
Tmp3 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
{Tmp2.getValue(1), Tmp2, DAG.getIntPtrConstant(0, dl)});
{Tmp2.getValue(1), Tmp2,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
Results.push_back(Tmp3);
Results.push_back(Tmp3.getValue(1));
break;
Expand Down Expand Up @@ -5562,7 +5565,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp1.getValue(1), Tmp1});
Tmp3 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
{Tmp2.getValue(1), Tmp2, DAG.getIntPtrConstant(0, dl)});
{Tmp2.getValue(1), Tmp2,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
Results.push_back(Tmp3);
Results.push_back(Tmp3.getValue(1));
break;
Expand Down
13 changes: 6 additions & 7 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,7 @@ SelectionDAG::getStrictFPExtendOrRound(SDValue Op, SDValue Chain,
VT.bitsGT(Op.getValueType())
? getNode(ISD::STRICT_FP_EXTEND, DL, {VT, MVT::Other}, {Chain, Op})
: getNode(ISD::STRICT_FP_ROUND, DL, {VT, MVT::Other},
{Chain, Op, getIntPtrConstant(0, DL)});
{Chain, Op, getIntPtrConstant(0, DL, /*isTarget=*/true)});

return std::pair<SDValue, SDValue>(Res, SDValue(Res.getNode(), 1));
}
Expand Down Expand Up @@ -7355,11 +7355,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return N1;
break;
case ISD::FP_ROUND:
assert(VT.isFloatingPoint() &&
N1.getValueType().isFloatingPoint() &&
VT.bitsLE(N1.getValueType()) &&
N2C && (N2C->getZExtValue() == 0 || N2C->getZExtValue() == 1) &&
"Invalid FP_ROUND!");
assert(VT.isFloatingPoint() && N1.getValueType().isFloatingPoint() &&
VT.bitsLE(N1.getValueType()) && N2C &&
(N2C->getZExtValue() == 0 || N2C->getZExtValue() == 1) &&
N2.getOpcode() == ISD::TargetConstant && "Invalid FP_ROUND!");
if (N1.getValueType() == VT) return N1; // noop conversion.
break;
case ISD::AssertSext:
Expand Down Expand Up @@ -10542,7 +10541,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
assert(VTList.VTs[0].isFloatingPoint() &&
Ops[1].getValueType().isFloatingPoint() &&
VTList.VTs[0].bitsLT(Ops[1].getValueType()) &&
isa<ConstantSDNode>(Ops[2]) &&
Ops[2].getOpcode() == ISD::TargetConstant &&
(Ops[2]->getAsZExtVal() == 0 || Ops[2]->getAsZExtVal() == 1) &&
"Invalid STRICT_FP_ROUND!");
break;
Expand Down
37 changes: 20 additions & 17 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4901,13 +4901,14 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
if (IsStrict) {
SDValue Val = DAG.getNode(Op.getOpcode(), dl, {F32, MVT::Other},
{Op.getOperand(0), In});
return DAG.getNode(
ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
{Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
return DAG.getNode(ISD::STRICT_FP_ROUND, dl,
{Op.getValueType(), MVT::Other},
{Val.getValue(1), Val.getValue(0),
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
}
return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
DAG.getNode(Op.getOpcode(), dl, F32, In),
DAG.getIntPtrConstant(0, dl));
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
}

uint64_t VTSize = VT.getFixedSizeInBits();
Expand All @@ -4919,9 +4920,9 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
if (IsStrict) {
In = DAG.getNode(Opc, dl, {CastVT, MVT::Other},
{Op.getOperand(0), In});
return DAG.getNode(
ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
{In.getValue(1), In.getValue(0), DAG.getIntPtrConstant(0, dl)});
return DAG.getNode(ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
{In.getValue(1), In.getValue(0),
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
}
In = DAG.getNode(Opc, dl, CastVT, In);
return DAG.getNode(ISD::FP_ROUND, dl, VT, In,
Expand Down Expand Up @@ -4969,13 +4970,14 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
if (IsStrict) {
SDValue Val = DAG.getNode(Op.getOpcode(), dl, {PromoteVT, MVT::Other},
{Op.getOperand(0), SrcVal});
return DAG.getNode(
ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
{Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
return DAG.getNode(ISD::STRICT_FP_ROUND, dl,
{Op.getValueType(), MVT::Other},
{Val.getValue(1), Val.getValue(0),
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
}
return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
DAG.getNode(Op.getOpcode(), dl, PromoteVT, SrcVal),
DAG.getIntPtrConstant(0, dl));
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
};

if (Op.getValueType() == MVT::bf16) {
Expand Down Expand Up @@ -5067,12 +5069,13 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, NeedsAdjustment);
SDValue Adjusted = DAG.getNode(ISD::BITCAST, DL, MVT::f64, AdjustedBits);
return IsStrict
? DAG.getNode(ISD::STRICT_FP_ROUND, DL,
{Op.getValueType(), MVT::Other},
{Rounded.getValue(1), Adjusted,
DAG.getIntPtrConstant(0, DL)})
? DAG.getNode(
ISD::STRICT_FP_ROUND, DL,
{Op.getValueType(), MVT::Other},
{Rounded.getValue(1), Adjusted,
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)})
: DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), Adjusted,
DAG.getIntPtrConstant(0, DL, true));
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
}
}

Expand Down Expand Up @@ -7109,7 +7112,7 @@ static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) {
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, X.getValueType(), FScale, Zero);
if (X.getValueType() != XScalarTy)
Final = DAG.getNode(ISD::FP_ROUND, DL, XScalarTy, Final,
DAG.getIntPtrConstant(1, SDLoc(Op)));
DAG.getIntPtrConstant(1, SDLoc(Op), /*isTarget=*/true));
return Final;
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10756,7 +10756,7 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
Tmp = DAG.getNode(ISD::BITCAST, SL, MVT::f32, TmpCast);
Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot, Op->getFlags());
SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot,
DAG.getConstant(0, SL, MVT::i32));
DAG.getTargetConstant(0, SL, MVT::i32));
return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, RDst, RHS, LHS,
Op->getFlags());
}
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1575,9 +1575,10 @@ HexagonTargetLowering::resizeToWidth(SDValue VecV, MVT ResTy, bool Signed,
unsigned ResWidth = ResTy.getSizeInBits();

if (InpTy.isFloatingPoint()) {
return InpWidth < ResWidth ? DAG.getNode(ISD::FP_EXTEND, dl, ResTy, VecV)
: DAG.getNode(ISD::FP_ROUND, dl, ResTy, VecV,
getZero(dl, MVT::i32, DAG));
return InpWidth < ResWidth
? DAG.getNode(ISD::FP_EXTEND, dl, ResTy, VecV)
: DAG.getNode(ISD::FP_ROUND, dl, ResTy, VecV,
DAG.getTargetConstant(0, dl, MVT::i32));
}

assert(InpTy.isInteger());
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2786,7 +2786,7 @@ SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
return DAG.getNode(
ISD::FP_ROUND, Loc, MVT::bf16,
DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
DAG.getIntPtrConstant(0, Loc));
DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true));
}

// Everything else is considered legal.
Expand Down
13 changes: 7 additions & 6 deletions llvm/lib/Target/PowerPC/PPCISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8963,9 +8963,10 @@ SDValue PPCTargetLowering::LowerINT_TO_FP(SDValue Op,

if (Op.getValueType() == MVT::f32 && !Subtarget.hasFPCVT()) {
if (IsStrict)
FP = DAG.getNode(ISD::STRICT_FP_ROUND, dl,
DAG.getVTList(MVT::f32, MVT::Other),
{Chain, FP, DAG.getIntPtrConstant(0, dl)}, Flags);
FP = DAG.getNode(
ISD::STRICT_FP_ROUND, dl, DAG.getVTList(MVT::f32, MVT::Other),
{Chain, FP, DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)},
Flags);
else
FP = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, FP,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
Expand Down Expand Up @@ -9044,9 +9045,9 @@ SDValue PPCTargetLowering::LowerINT_TO_FP(SDValue Op,
Chain = FP.getValue(1);
if (Op.getValueType() == MVT::f32 && !Subtarget.hasFPCVT()) {
if (IsStrict)
FP = DAG.getNode(ISD::STRICT_FP_ROUND, dl,
DAG.getVTList(MVT::f32, MVT::Other),
{Chain, FP, DAG.getIntPtrConstant(0, dl)}, Flags);
FP = DAG.getNode(
ISD::STRICT_FP_ROUND, dl, DAG.getVTList(MVT::f32, MVT::Other),
{Chain, FP, DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)}, Flags);
else
FP = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, FP,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19595,7 +19595,7 @@ static SDValue promoteXINT_TO_FP(SDValue Op, const SDLoc &dl,
MVT VT = Op.getSimpleValueType();
MVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;

SDValue Rnd = DAG.getIntPtrConstant(0, dl);
SDValue Rnd = DAG.getIntPtrConstant(0, dl, /*isTarget=*/true);
if (IsStrict)
return DAG.getNode(
ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
Expand Down Expand Up @@ -20266,7 +20266,8 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op,
if (DstVT == MVT::f80)
return Add;
return DAG.getNode(ISD::STRICT_FP_ROUND, dl, {DstVT, MVT::Other},
{Add.getValue(1), Add, DAG.getIntPtrConstant(0, dl)});
{Add.getValue(1), Add,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
}
unsigned Opc = ISD::FADD;
// Windows needs the precision control changed to 80bits around this add.
Expand Down
Loading