Skip to content

Scalarize the vector inputs to llvm.lround intrinsic by default. #101054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16819,7 +16819,8 @@ Syntax:
"""""""

This is an overloaded intrinsic. You can use ``llvm.lround`` on any
floating-point type. Not all targets support all types however.
floating-point type or vector of floating-point type. Not all targets
support all types however.

::

Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4921,6 +4921,8 @@ LegalizerHelper::fewerElementsVector(MachineInstr &MI, unsigned TypeIdx,
case G_INTRINSIC_LLRINT:
case G_INTRINSIC_ROUND:
case G_INTRINSIC_ROUNDEVEN:
case G_LROUND:
case G_LLROUND:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lit test is not testing this but it will complete the handling of cases in line with similar intrinsic, LRINT.

case G_INTRINSIC_TRUNC:
case G_FCOS:
case G_FSIN:
Expand Down
15 changes: 14 additions & 1 deletion llvm/lib/CodeGen/MachineVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2062,7 +2062,20 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
}
case TargetOpcode::G_LLROUND:
case TargetOpcode::G_LROUND: {
verifyAllRegOpsScalar(*MI, *MRI);
LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
LLT SrcTy = MRI->getType(MI->getOperand(1).getReg());
if (!DstTy.isValid() || !SrcTy.isValid())
break;
if (SrcTy.isPointer() || DstTy.isPointer()) {
StringRef Op = SrcTy.isPointer() ? "Source" : "Destination";
report(Twine(Op, " operand must not be a pointer type"), MI);
} else if (SrcTy.isScalar()) {
verifyAllRegOpsScalar(*MI, *MRI);
break;
} else if (SrcTy.isVector()) {
verifyVectorElementMatch(SrcTy, DstTy, MI);
break;
}
break;
}
case TargetOpcode::G_IS_FPCLASS: {
Expand Down
10 changes: 7 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ namespace {
SDValue visitUINT_TO_FP(SDNode *N);
SDValue visitFP_TO_SINT(SDNode *N);
SDValue visitFP_TO_UINT(SDNode *N);
SDValue visitXRINT(SDNode *N);
SDValue visitXROUND(SDNode *N);
SDValue visitFP_ROUND(SDNode *N);
SDValue visitFP_EXTEND(SDNode *N);
SDValue visitFNEG(SDNode *N);
Expand Down Expand Up @@ -1929,8 +1929,10 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
case ISD::LROUND:
case ISD::LLROUND:
case ISD::LRINT:
case ISD::LLRINT: return visitXRINT(N);
case ISD::LLRINT: return visitXROUND(N);
case ISD::FP_ROUND: return visitFP_ROUND(N);
case ISD::FP_EXTEND: return visitFP_EXTEND(N);
case ISD::FNEG: return visitFNEG(N);
Expand Down Expand Up @@ -17984,15 +17986,17 @@ SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
return FoldIntToFPToInt(N, DAG);
}

SDValue DAGCombiner::visitXRINT(SDNode *N) {
SDValue DAGCombiner::visitXROUND(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);

// fold (lrint|llrint undef) -> undef
// fold (lround|llround undef) -> undef
if (N0.isUndef())
return DAG.getUNDEF(VT);

// fold (lrint|llrint c1fp) -> c1
// fold (lround|llround c1fp) -> c1
if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0);

Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2441,6 +2441,8 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
case ISD::FCOPYSIGN: R = PromoteFloatOp_FCOPYSIGN(N, OpNo); break;
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
case ISD::LROUND:
case ISD::LLROUND:
case ISD::LRINT:
case ISD::LLRINT: R = PromoteFloatOp_UnaryOp(N, OpNo); break;
case ISD::FP_TO_SINT_SAT:
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecRes_Convert(SDNode *N);
SDValue WidenVecRes_Convert_StrictFP(SDNode *N);
SDValue WidenVecRes_FP_TO_XINT_SAT(SDNode *N);
SDValue WidenVecRes_XRINT(SDNode *N);
SDValue WidenVecRes_XROUND(SDNode *N);
SDValue WidenVecRes_FCOPYSIGN(SDNode *N);
SDValue WidenVecRes_UnarySameEltsWithScalarArg(SDNode *N);
SDValue WidenVecRes_ExpOp(SDNode *N);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
Node->getValueType(0), Scale);
break;
}
case ISD::LROUND:
case ISD::LLROUND:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::SINT_TO_FP:
Expand Down
16 changes: 14 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
case ISD::LLRINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::LROUND:
case ISD::LLROUND:
case ISD::FSIN:
case ISD::FSINH:
case ISD::FSQRT:
Expand Down Expand Up @@ -752,6 +754,8 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::FP_TO_UINT:
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
case ISD::LROUND:
case ISD::LLROUND:
case ISD::LRINT:
case ISD::LLRINT:
Res = ScalarizeVecOp_UnaryOp(N);
Expand Down Expand Up @@ -1215,6 +1219,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::VP_FROUND:
case ISD::FROUNDEVEN:
case ISD::VP_FROUNDEVEN:
case ISD::LROUND:
case ISD::LLROUND:
case ISD::FSIN:
case ISD::FSINH:
case ISD::FSQRT: case ISD::VP_SQRT:
Expand Down Expand Up @@ -3270,6 +3276,8 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::ZERO_EXTEND:
case ISD::ANY_EXTEND:
case ISD::FTRUNC:
case ISD::LROUND:
case ISD::LLROUND:
case ISD::LRINT:
case ISD::LLRINT:
Res = SplitVecOp_UnaryOp(N);
Expand Down Expand Up @@ -4594,7 +4602,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
case ISD::LLRINT:
case ISD::VP_LRINT:
case ISD::VP_LLRINT:
Res = WidenVecRes_XRINT(N);
case ISD::LROUND:
case ISD::LLROUND:
Res = WidenVecRes_XROUND(N);
break;

case ISD::FABS:
Expand Down Expand Up @@ -5211,7 +5221,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_FP_TO_XINT_SAT(SDNode *N) {
return DAG.getNode(N->getOpcode(), dl, WidenVT, Src, N->getOperand(1));
}

SDValue DAGTypeLegalizer::WidenVecRes_XRINT(SDNode *N) {
SDValue DAGTypeLegalizer::WidenVecRes_XROUND(SDNode *N) {
SDLoc dl(N);
EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
ElementCount WidenNumElts = WidenVT.getVectorElementCount();
Expand Down Expand Up @@ -6460,6 +6470,8 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VSELECT: Res = WidenVecOp_VSELECT(N); break;
case ISD::FLDEXP:
case ISD::FCOPYSIGN:
case ISD::LROUND:
case ISD::LLROUND:
case ISD::LRINT:
case ISD::LLRINT:
Res = WidenVecOp_UnrollVectorOp(N);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5436,6 +5436,8 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const
case ISD::FCEIL:
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::LROUND:
case ISD::LLROUND:
case ISD::FRINT:
case ISD::LRINT:
case ISD::LLRINT:
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,8 +774,9 @@ void TargetLoweringBase::initActions() {
setOperationAction(
{ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG, ISD::ANY_EXTEND_VECTOR_INREG,
ISD::SIGN_EXTEND_VECTOR_INREG, ISD::ZERO_EXTEND_VECTOR_INREG,
ISD::SPLAT_VECTOR, ISD::LRINT, ISD::LLRINT, ISD::FTAN, ISD::FACOS,
ISD::FASIN, ISD::FATAN, ISD::FCOSH, ISD::FSINH, ISD::FTANH},
ISD::SPLAT_VECTOR, ISD::LRINT, ISD::LLRINT, ISD::LROUND,
ISD::LLROUND, ISD::FTAN, ISD::FACOS, ISD::FASIN, ISD::FATAN,
ISD::FCOSH, ISD::FSINH, ISD::FTANH},
VT, Expand);

// Constrained floating-point operations default to expand.
Expand Down
18 changes: 16 additions & 2 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5975,8 +5975,22 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
case Intrinsic::llround: {
Type *ValTy = Call.getArgOperand(0)->getType();
Type *ResultTy = Call.getType();
Check(!ValTy->isVectorTy() && !ResultTy->isVectorTy(),
"Intrinsic does not support vectors", &Call);
auto *VTy = dyn_cast<VectorType>(ValTy);
auto *RTy = dyn_cast<VectorType>(ResultTy);
Check(
ValTy->isFPOrFPVectorTy() && ResultTy->isIntOrIntVectorTy(),
"llvm.lround, llvm.llround: argument must be floating-point or vector "
"of floating-points, and result must be integer or vector of integers",
&Call);
Check(
ValTy->isVectorTy() == ResultTy->isVectorTy(),
"llvm.lround, llvm.llround: argument and result disagree on vector use",
&Call);
if (VTy) {
Check(VTy->getElementCount() == RTy->getElementCount(),
"llvm.lround, llvm.llround: argument must be same length as result",
&Call);
}
break;
}
case Intrinsic::bswap: {
Expand Down
26 changes: 26 additions & 0 deletions llvm/test/Assembler/lround.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
; Validate that vector types are accepted for llvm.lround/llvm.llround intrinsic
; RUN: llvm-as < %s | llvm-dis | FileCheck %s

define <2 x i32> @intrinsic_lround_v2i32_v2f32(<2 x float> %arg) {
;CHECK: %res = tail call <2 x i32> @llvm.lround.v2i32.v2f32(<2 x float> %arg)
%res = tail call <2 x i32> @llvm.lround.v2i32.v2f32(<2 x float> %arg)
ret <2 x i32> %res
}

define <2 x i32> @intrinsic_llround_v2i32_v2f32(<2 x float> %arg) {
;CHECK: %res = tail call <2 x i32> @llvm.llround.v2i32.v2f32(<2 x float> %arg)
%res = tail call <2 x i32> @llvm.llround.v2i32.v2f32(<2 x float> %arg)
ret <2 x i32> %res
}

define <2 x i64> @intrinsic_lround_v2i64_v2f32(<2 x float> %arg) {
;CHECK: %res = tail call <2 x i64> @llvm.lround.v2i64.v2f32(<2 x float> %arg)
%res = tail call <2 x i64> @llvm.lround.v2i64.v2f32(<2 x float> %arg)
ret <2 x i64> %res
}

define <2 x i64> @intrinsic_llround_v2i64_v2f32(<2 x float> %arg) {
;CHECK: %res = tail call <2 x i64> @llvm.llround.v2i64.v2f32(<2 x float> %arg)
%res = tail call <2 x i64> @llvm.llround.v2i64.v2f32(<2 x float> %arg)
ret <2 x i64> %res
}
Loading