Skip to content

Commit 28cdbbe

Browse files
[AArch64] Combine getActiveLaneMask with vector_extract
... into a `whilelo` instruction with a pair of predicate registers.
1 parent 8511b32 commit 28cdbbe

File tree

3 files changed

+515
-35
lines changed

3 files changed

+515
-35
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 87 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,8 +1815,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
18151815

18161816
bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
18171817
EVT OpVT) const {
1818-
// Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
1819-
if (!Subtarget->hasSVE())
1818+
// Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
1819+
if (!Subtarget->hasSVEorSME())
18201820
return true;
18211821

18221822
// We can only support legal predicate result types. We can use the SVE
@@ -20054,47 +20054,99 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
2005420054
return SDValue();
2005520055
}
2005620056

20057-
static SDValue performIntrinsicCombine(SDNode *N,
20058-
TargetLowering::DAGCombinerInfo &DCI,
20059-
const AArch64Subtarget *Subtarget) {
20057+
static SDValue tryCombineGetActiveLaneMask(SDNode *N,
20058+
TargetLowering::DAGCombinerInfo &DCI,
20059+
const AArch64Subtarget *Subtarget) {
2006020060
SelectionDAG &DAG = DCI.DAG;
20061-
unsigned IID = getIntrinsicID(N);
20062-
switch (IID) {
20063-
default:
20064-
break;
20065-
case Intrinsic::get_active_lane_mask: {
20066-
SDValue Res = SDValue();
20067-
EVT VT = N->getValueType(0);
20068-
if (VT.isFixedLengthVector()) {
20069-
// We can use the SVE whilelo instruction to lower this intrinsic by
20070-
// creating the appropriate sequence of scalable vector operations and
20071-
// then extracting a fixed-width subvector from the scalable vector.
20061+
EVT VT = N->getValueType(0);
20062+
if (VT.isFixedLengthVector()) {
20063+
// We can use the SVE whilelo instruction to lower this intrinsic by
20064+
// creating the appropriate sequence of scalable vector operations and
20065+
// then extracting a fixed-width subvector from the scalable vector.
20066+
SDLoc DL(N);
20067+
SDValue ID =
20068+
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
2007220069

20073-
SDLoc DL(N);
20074-
SDValue ID =
20075-
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
20070+
EVT WhileVT =
20071+
EVT::getVectorVT(*DAG.getContext(), MVT::i1,
20072+
ElementCount::getScalable(VT.getVectorNumElements()));
2007620073

20077-
EVT WhileVT = EVT::getVectorVT(
20078-
*DAG.getContext(), MVT::i1,
20079-
ElementCount::getScalable(VT.getVectorNumElements()));
20074+
// Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
20075+
EVT PromVT = getPromotedVTForPredicate(WhileVT);
2008020076

20081-
// Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
20082-
EVT PromVT = getPromotedVTForPredicate(WhileVT);
20077+
// Get the fixed-width equivalent of PromVT for extraction.
20078+
EVT ExtVT =
20079+
EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
20080+
VT.getVectorElementCount());
2008320081

20084-
// Get the fixed-width equivalent of PromVT for extraction.
20085-
EVT ExtVT =
20086-
EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
20087-
VT.getVectorElementCount());
20082+
SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
20083+
N->getOperand(1), N->getOperand(2));
20084+
Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
20085+
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
20086+
DAG.getConstant(0, DL, MVT::i64));
20087+
Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
2008820088

20089-
Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
20090-
N->getOperand(1), N->getOperand(2));
20091-
Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
20092-
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
20093-
DAG.getConstant(0, DL, MVT::i64));
20094-
Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
20095-
}
2009620089
return Res;
2009720090
}
20091+
20092+
const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
20093+
if (HalfSize < 2)
20094+
return SDValue();
20095+
20096+
if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
20097+
return SDValue();
20098+
20099+
if (!N->hasNUsesOfValue(2, 0))
20100+
return SDValue();
20101+
20102+
auto It = N->use_begin();
20103+
SDNode *Lo = *It++;
20104+
SDNode *Hi = *It;
20105+
20106+
uint64_t OffLo, OffHi;
20107+
if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
20108+
!isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
20109+
Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
20110+
!isIntImmediate(Hi->getOperand(1).getNode(), OffHi))
20111+
return SDValue();
20112+
20113+
if (OffLo > OffHi) {
20114+
std::swap(Lo, Hi);
20115+
std::swap(OffLo, OffHi);
20116+
}
20117+
20118+
if (OffLo != 0 || OffHi != HalfSize)
20119+
return SDValue();
20120+
20121+
SDLoc DL(N);
20122+
SDValue ID =
20123+
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
20124+
SDValue Idx = N->getOperand(1);
20125+
SDValue TC = N->getOperand(2);
20126+
if (Idx.getValueType() != MVT::i64) {
20127+
Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
20128+
TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
20129+
}
20130+
auto R =
20131+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
20132+
{Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
20133+
20134+
DCI.CombineTo(Lo, R.getValue(0));
20135+
DCI.CombineTo(Hi, R.getValue(1));
20136+
20137+
return SDValue(N, 0);
20138+
}
20139+
20140+
static SDValue performIntrinsicCombine(SDNode *N,
20141+
TargetLowering::DAGCombinerInfo &DCI,
20142+
const AArch64Subtarget *Subtarget) {
20143+
SelectionDAG &DAG = DCI.DAG;
20144+
unsigned IID = getIntrinsicID(N);
20145+
switch (IID) {
20146+
default:
20147+
break;
20148+
case Intrinsic::get_active_lane_mask:
20149+
return tryCombineGetActiveLaneMask(N, DCI, Subtarget);
2009820150
case Intrinsic::aarch64_neon_vcvtfxs2fp:
2009920151
case Intrinsic::aarch64_neon_vcvtfxu2fp:
2010020152
return tryCombineFixedPointConvert(N, DCI, DAG);

llvm/test/CodeGen/AArch64/active_lane_mask.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
22
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
3+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme < %s | FileCheck %s
34

45
; == Scalable ==
56

0 commit comments

Comments
 (0)