Skip to content

Commit 0f9031c

Browse files
[AArch64] Combine getActiveLaneMask with vector_extract
... into a `whilelo` instruction with a pair of predicate registers.
1 parent 8461d90 commit 0f9031c

File tree

3 files changed

+455
-2
lines changed

3 files changed

+455
-2
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,8 +1847,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
18471847

18481848
bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
18491849
EVT OpVT) const {
1850-
// Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
1851-
if (!Subtarget->hasSVE())
1850+
// Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
1851+
if (!Subtarget->hasSVEorSME())
18521852
return true;
18531853

18541854
// We can only support legal predicate result types. We can use the SVE
@@ -20481,6 +20481,61 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
2048120481
return SDValue();
2048220482
}
2048320483

20484+
static SDValue tryCombineWhileLo(SDNode *N,
20485+
TargetLowering::DAGCombinerInfo &DCI,
20486+
const AArch64Subtarget *Subtarget) {
20487+
if (DCI.isBeforeLegalize())
20488+
return SDValue();
20489+
20490+
if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
20491+
return SDValue();
20492+
20493+
if (!N->hasNUsesOfValue(2, 0))
20494+
return SDValue();
20495+
20496+
const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
20497+
if (HalfSize < 2)
20498+
return SDValue();
20499+
20500+
auto It = N->use_begin();
20501+
SDNode *Lo = *It++;
20502+
SDNode *Hi = *It;
20503+
20504+
uint64_t OffLo, OffHi;
20505+
if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
20506+
!isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
20507+
Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
20508+
!isIntImmediate(Hi->getOperand(1).getNode(), OffHi))
20509+
return SDValue();
20510+
20511+
if (OffLo > OffHi) {
20512+
std::swap(Lo, Hi);
20513+
std::swap(OffLo, OffHi);
20514+
}
20515+
20516+
if (OffLo != 0 || OffHi != HalfSize)
20517+
return SDValue();
20518+
20519+
SelectionDAG &DAG = DCI.DAG;
20520+
SDLoc DL(N);
20521+
SDValue ID =
20522+
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
20523+
SDValue Idx = N->getOperand(1);
20524+
SDValue TC = N->getOperand(2);
20525+
if (Idx.getValueType() != MVT::i64) {
20526+
Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
20527+
TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
20528+
}
20529+
auto R =
20530+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
20531+
{Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
20532+
20533+
DCI.CombineTo(Lo, R.getValue(0));
20534+
DCI.CombineTo(Hi, R.getValue(1));
20535+
20536+
return SDValue(N, 0);
20537+
}
20538+
2048420539
static SDValue performIntrinsicCombine(SDNode *N,
2048520540
TargetLowering::DAGCombinerInfo &DCI,
2048620541
const AArch64Subtarget *Subtarget) {
@@ -20811,6 +20866,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
2081120866
case Intrinsic::aarch64_sve_ptest_last:
2081220867
return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
2081320868
AArch64CC::LAST_ACTIVE);
20869+
case Intrinsic::aarch64_sve_whilelo:
20870+
return tryCombineWhileLo(N, DCI, Subtarget);
2081420871
}
2081520872
return SDValue();
2081620873
}

llvm/test/CodeGen/AArch64/active_lane_mask.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
22
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
3+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme < %s | FileCheck %s
34

45
; == Scalable ==
56

0 commit comments

Comments
 (0)