Skip to content

Commit 2c0a226

Browse files
authored
[AArch64] Spare N2I roundtrip when splatting float comparison (#141806)
Transform `select_cc t1, t2, -1, 0` for floats into a vector comparison which generates a mask, which is later on combined with potential vectorized DUPs.
1 parent 56ebe64 commit 2c0a226

File tree

4 files changed

+557
-56
lines changed

4 files changed

+557
-56
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 173 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11048,10 +11048,126 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
1104811048
Cmp.getValue(1));
1104911049
}
1105011050

11051-
SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
11052-
SDValue RHS, SDValue TVal,
11053-
SDValue FVal, const SDLoc &dl,
11054-
SelectionDAG &DAG) const {
11051+
/// Emit vector comparison for floating-point values, producing a mask.
11052+
static SDValue emitVectorComparison(SDValue LHS, SDValue RHS,
11053+
AArch64CC::CondCode CC, bool NoNans, EVT VT,
11054+
const SDLoc &DL, SelectionDAG &DAG) {
11055+
EVT SrcVT = LHS.getValueType();
11056+
assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
11057+
"function only supposed to emit natural comparisons");
11058+
11059+
switch (CC) {
11060+
default:
11061+
return SDValue();
11062+
case AArch64CC::NE: {
11063+
SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
11064+
// Use vector semantics for the inversion to potentially save a copy between
11065+
// SIMD and regular registers.
11066+
if (!LHS.getValueType().isVector()) {
11067+
EVT VecVT =
11068+
EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
11069+
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
11070+
SDValue MaskVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT,
11071+
DAG.getUNDEF(VecVT), Fcmeq, Zero);
11072+
SDValue InvertedMask = DAG.getNOT(DL, MaskVec, VecVT);
11073+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, InvertedMask, Zero);
11074+
}
11075+
return DAG.getNOT(DL, Fcmeq, VT);
11076+
}
11077+
case AArch64CC::EQ:
11078+
return DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
11079+
case AArch64CC::GE:
11080+
return DAG.getNode(AArch64ISD::FCMGE, DL, VT, LHS, RHS);
11081+
case AArch64CC::GT:
11082+
return DAG.getNode(AArch64ISD::FCMGT, DL, VT, LHS, RHS);
11083+
case AArch64CC::LE:
11084+
if (!NoNans)
11085+
return SDValue();
11086+
// If we ignore NaNs then we can use to the LS implementation.
11087+
[[fallthrough]];
11088+
case AArch64CC::LS:
11089+
return DAG.getNode(AArch64ISD::FCMGE, DL, VT, RHS, LHS);
11090+
case AArch64CC::LT:
11091+
if (!NoNans)
11092+
return SDValue();
11093+
// If we ignore NaNs then we can use to the MI implementation.
11094+
[[fallthrough]];
11095+
case AArch64CC::MI:
11096+
return DAG.getNode(AArch64ISD::FCMGT, DL, VT, RHS, LHS);
11097+
}
11098+
}
11099+
11100+
/// For SELECT_CC, when the true/false values are (-1, 0) and the compared
11101+
/// values are scalars, try to emit a mask generating vector instruction.
11102+
static SDValue emitFloatCompareMask(SDValue LHS, SDValue RHS, SDValue TVal,
11103+
SDValue FVal, ISD::CondCode CC, bool NoNaNs,
11104+
const SDLoc &DL, SelectionDAG &DAG) {
11105+
assert(!LHS.getValueType().isVector());
11106+
assert(!RHS.getValueType().isVector());
11107+
11108+
auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
11109+
auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
11110+
if (!CTVal || !CFVal)
11111+
return {};
11112+
if (!(CTVal->isAllOnes() && CFVal->isZero()) &&
11113+
!(CTVal->isZero() && CFVal->isAllOnes()))
11114+
return {};
11115+
11116+
if (CTVal->isZero())
11117+
CC = ISD::getSetCCInverse(CC, LHS.getValueType());
11118+
11119+
EVT VT = TVal.getValueType();
11120+
if (VT.getSizeInBits() != LHS.getValueType().getSizeInBits())
11121+
return {};
11122+
11123+
if (!NoNaNs && (CC == ISD::SETUO || CC == ISD::SETO)) {
11124+
bool OneNaN = false;
11125+
if (LHS == RHS) {
11126+
OneNaN = true;
11127+
} else if (DAG.isKnownNeverNaN(RHS)) {
11128+
OneNaN = true;
11129+
RHS = LHS;
11130+
} else if (DAG.isKnownNeverNaN(LHS)) {
11131+
OneNaN = true;
11132+
LHS = RHS;
11133+
}
11134+
if (OneNaN)
11135+
CC = (CC == ISD::SETUO) ? ISD::SETUNE : ISD::SETOEQ;
11136+
}
11137+
11138+
AArch64CC::CondCode CC1;
11139+
AArch64CC::CondCode CC2;
11140+
bool ShouldInvert = false;
11141+
changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
11142+
SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, DL, DAG);
11143+
SDValue Cmp2;
11144+
if (CC2 != AArch64CC::AL) {
11145+
Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, VT, DL, DAG);
11146+
if (!Cmp2)
11147+
return {};
11148+
}
11149+
if (!Cmp2 && !ShouldInvert)
11150+
return Cmp;
11151+
11152+
EVT VecVT = EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
11153+
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
11154+
Cmp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT), Cmp,
11155+
Zero);
11156+
if (Cmp2) {
11157+
Cmp2 = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT),
11158+
Cmp2, Zero);
11159+
Cmp = DAG.getNode(ISD::OR, DL, VecVT, Cmp, Cmp2);
11160+
}
11161+
if (ShouldInvert)
11162+
Cmp = DAG.getNOT(DL, Cmp, VecVT);
11163+
Cmp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Cmp, Zero);
11164+
return Cmp;
11165+
}
11166+
11167+
SDValue AArch64TargetLowering::LowerSELECT_CC(
11168+
ISD::CondCode CC, SDValue LHS, SDValue RHS, SDValue TVal, SDValue FVal,
11169+
iterator_range<SDNode::user_iterator> Users, bool HasNoNaNs,
11170+
const SDLoc &dl, SelectionDAG &DAG) const {
1105511171
// Handle f128 first, because it will result in a comparison of some RTLIB
1105611172
// call result against zero.
1105711173
if (LHS.getValueType() == MVT::f128) {
@@ -11234,6 +11350,27 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1123411350
LHS.getValueType() == MVT::f64);
1123511351
assert(LHS.getValueType() == RHS.getValueType());
1123611352
EVT VT = TVal.getValueType();
11353+
11354+
// If the purpose of the comparison is to select between all ones
11355+
// or all zeros, try to use a vector comparison because the operands are
11356+
// already stored in SIMD registers.
11357+
if (Subtarget->isNeonAvailable() && all_of(Users, [](const SDNode *U) {
11358+
switch (U->getOpcode()) {
11359+
default:
11360+
return false;
11361+
case ISD::INSERT_VECTOR_ELT:
11362+
case ISD::SCALAR_TO_VECTOR:
11363+
case AArch64ISD::DUP:
11364+
return true;
11365+
}
11366+
})) {
11367+
bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
11368+
SDValue VectorCmp =
11369+
emitFloatCompareMask(LHS, RHS, TVal, FVal, CC, NoNaNs, dl, DAG);
11370+
if (VectorCmp)
11371+
return VectorCmp;
11372+
}
11373+
1123711374
SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
1123811375

1123911376
// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11320,15 +11457,18 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
1132011457
SDValue RHS = Op.getOperand(1);
1132111458
SDValue TVal = Op.getOperand(2);
1132211459
SDValue FVal = Op.getOperand(3);
11460+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1132311461
SDLoc DL(Op);
11324-
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11462+
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNans, DL,
11463+
DAG);
1132511464
}
1132611465

1132711466
SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1132811467
SelectionDAG &DAG) const {
1132911468
SDValue CCVal = Op->getOperand(0);
1133011469
SDValue TVal = Op->getOperand(1);
1133111470
SDValue FVal = Op->getOperand(2);
11471+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1133211472
SDLoc DL(Op);
1133311473

1133411474
EVT Ty = Op.getValueType();
@@ -11395,7 +11535,8 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1139511535
DAG.getUNDEF(MVT::f32), FVal);
1139611536
}
1139711537

11398-
SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11538+
SDValue Res =
11539+
LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNans, DL, DAG);
1139911540

1140011541
if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
1140111542
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15648,47 +15789,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1564815789
llvm_unreachable("unexpected shift opcode");
1564915790
}
1565015791

15651-
/// Emit a NEON comparison between floating-point vectors, producing a mask of
/// type \p VT, or an empty SDValue if the condition cannot be expressed as a
/// single comparison instruction. \p VT must have the same overall width as
/// the operand type.
static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
                                    AArch64CC::CondCode CC, bool NoNans, EVT VT,
                                    const SDLoc &dl, SelectionDAG &DAG) {
  EVT SrcVT = LHS.getValueType();
  assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
         "function only supposed to emit natural comparisons");

  // Only floating-point element types are handled here.
  if (!SrcVT.getVectorElementType().isFloatingPoint())
    return SDValue();

  switch (CC) {
  default:
    return SDValue();
  case AArch64CC::NE: {
    // NE is emitted as an inverted FCMEQ.
    SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
    return DAG.getNOT(dl, Fcmeq, VT);
  }
  case AArch64CC::EQ:
    return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
  case AArch64CC::GE:
    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
  case AArch64CC::GT:
    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
  case AArch64CC::LE:
    // If we ignore NaNs then we can use the LS implementation.
    if (!NoNans)
      return SDValue();
    [[fallthrough]];
  case AArch64CC::LS:
    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
  case AArch64CC::LT:
    // If we ignore NaNs then we can use the MI implementation.
    if (!NoNans)
      return SDValue();
    [[fallthrough]];
  case AArch64CC::MI:
    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
  }
}
15691-
1569215792
SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
1569315793
SelectionDAG &DAG) const {
1569415794
if (Op.getValueType().isScalableVector())
@@ -15737,15 +15837,14 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
1573715837
bool ShouldInvert;
1573815838
changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
1573915839

15740-
bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
15741-
SDValue Cmp =
15742-
EmitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
15840+
bool NoNaNs =
15841+
getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
15842+
SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
1574315843
if (!Cmp.getNode())
1574415844
return SDValue();
1574515845

1574615846
if (CC2 != AArch64CC::AL) {
15747-
SDValue Cmp2 =
15748-
EmitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
15847+
SDValue Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
1574915848
if (!Cmp2.getNode())
1575015849
return SDValue();
1575115850

@@ -25502,6 +25601,28 @@ static SDValue performDUPCombine(SDNode *N,
2550225601
}
2550325602

2550425603
if (N->getOpcode() == AArch64ISD::DUP) {
25604+
// If the instruction is known to produce a scalar in SIMD registers, we can
25605+
// duplicate it across the vector lanes using DUPLANE instead of moving it
25606+
// to a GPR first. For example, this allows us to handle:
25607+
// v4i32 = DUP (i32 (FCMGT (f32, f32)))
25608+
SDValue Op = N->getOperand(0);
25609+
// FIXME: Ideally, we should be able to handle all instructions that
25610+
// produce a scalar value in FPRs.
25611+
if (Op.getOpcode() == AArch64ISD::FCMEQ ||
25612+
Op.getOpcode() == AArch64ISD::FCMGE ||
25613+
Op.getOpcode() == AArch64ISD::FCMGT) {
25614+
EVT ElemVT = VT.getVectorElementType();
25615+
EVT ExpandedVT = VT;
25616+
// Insert into a 128-bit vector to match DUPLANE's pattern.
25617+
if (VT.getSizeInBits() != 128)
25618+
ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
25619+
128 / ElemVT.getSizeInBits());
25620+
SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
25621+
SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
25622+
DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
25623+
return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
25624+
}
25625+
2550525626
if (DCI.isAfterLegalizeDAG()) {
2550625627
// If scalar dup's operand is extract_vector_elt, try to combine them into
2550725628
// duplane. For example,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,9 @@ class AArch64TargetLowering : public TargetLowering {
647647
SDValue LowerSELECT(SDValue Op, SelectionDAG &DAG) const;
648648
SDValue LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const;
649649
SDValue LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, SDValue RHS,
650-
SDValue TVal, SDValue FVal, const SDLoc &dl,
650+
SDValue TVal, SDValue FVal,
651+
iterator_range<SDNode::user_iterator> Users,
652+
bool HasNoNans, const SDLoc &dl,
651653
SelectionDAG &DAG) const;
652654
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
653655
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,9 @@ define <1 x i16> @test_select_f16_i16(half %i105, half %in, <1 x i16> %x, <1 x i
174174
; CHECK-LABEL: test_select_f16_i16:
175175
; CHECK: // %bb.0:
176176
; CHECK-NEXT: fcvt s0, h0
177-
; CHECK-NEXT: fcmp s0, s0
178-
; CHECK-NEXT: csetm w8, vs
179-
; CHECK-NEXT: dup v0.4h, w8
177+
; CHECK-NEXT: fcmeq s0, s0, s0
178+
; CHECK-NEXT: mvn v0.16b, v0.16b
179+
; CHECK-NEXT: dup v0.4h, v0.h[0]
180180
; CHECK-NEXT: bsl v0.8b, v2.8b, v3.8b
181181
; CHECK-NEXT: ret
182182
%i179 = fcmp uno half %i105, zeroinitializer

0 commit comments

Comments
 (0)