Skip to content

Commit c943b04

Browse files
committed
[AArch64] Lower alias mask to a whilewr
llvm#100579 emits IR that creates a mask disabling lanes that could alias within a loop iteration, based on a pair of pointers. This PR lowers that IR to a WHILEWR instruction for AArch64.
1 parent c7a3346 commit c943b04

File tree

2 files changed

+966
-0
lines changed

2 files changed

+966
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
#include <bitset>
9595
#include <cassert>
9696
#include <cctype>
97+
#include <cmath>
9798
#include <cstdint>
9899
#include <cstdlib>
99100
#include <iterator>
@@ -1523,6 +1524,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15231524
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
15241525
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
15251526
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1527+
setOperationAction(ISD::OR, VT, Custom);
15261528

15271529
setOperationAction(ISD::SELECT_CC, VT, Expand);
15281530
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
@@ -13782,8 +13784,88 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
1378213784
return ResultSLI;
1378313785
}
1378413786

13787+
/// Try to lower the construction of a pointer alias mask to a WHILEWR.
13788+
/// The mask's enabled lanes represent the elements that will not overlap across one loop iteration.
13789+
/// This tries to match:
13790+
/// or (splat (setcc_lt (sub ptrA, ptrB), -(element_size - 1))),
13791+
/// (get_active_lane_mask 0, (div (sub ptrA, ptrB), element_size))
13792+
SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG) {
13793+
if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE2())
13794+
return SDValue();
13795+
auto LaneMask = Op.getOperand(0);
13796+
auto Splat = Op.getOperand(1);
13797+
13798+
if (LaneMask.getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
13799+
LaneMask.getConstantOperandVal(0) != Intrinsic::get_active_lane_mask ||
13800+
Splat.getOpcode() != ISD::SPLAT_VECTOR)
13801+
return SDValue();
13802+
13803+
auto Cmp = Splat.getOperand(0);
13804+
if (Cmp.getOpcode() != ISD::SETCC)
13805+
return SDValue();
13806+
13807+
CondCodeSDNode *Cond = dyn_cast<CondCodeSDNode>(Cmp.getOperand(2));
13808+
assert(Cond && "SETCC doesn't have a condition code");
13809+
13810+
auto ComparatorConst = dyn_cast<ConstantSDNode>(Cmp.getOperand(1));
13811+
if (!ComparatorConst || ComparatorConst->getSExtValue() > 0 ||
13812+
Cond->get() != ISD::CondCode::SETLT)
13813+
return SDValue();
13814+
unsigned CompValue = std::abs(ComparatorConst->getSExtValue());
13815+
unsigned EltSize = CompValue + 1;
13816+
if (!isPowerOf2_64(EltSize) || EltSize > 64)
13817+
return SDValue();
13818+
13819+
auto Diff = Cmp.getOperand(0);
13820+
if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
13821+
return SDValue();
13822+
13823+
auto LaneMaskConst = dyn_cast<ConstantSDNode>(LaneMask.getOperand(1));
13824+
if (!LaneMaskConst || LaneMaskConst->getZExtValue() != 0 ||
13825+
(EltSize != 1 && LaneMask.getOperand(2).getOpcode() != ISD::SRA))
13826+
return SDValue();
13827+
13828+
// An alias mask for i8 elements omits the division because it would just divide by 1
13829+
if (EltSize > 1) {
13830+
auto DiffDiv = LaneMask.getOperand(2);
13831+
auto DiffDivConst = dyn_cast<ConstantSDNode>(DiffDiv.getOperand(1));
13832+
if (!DiffDivConst || DiffDivConst->getZExtValue() != std::log2(EltSize))
13833+
return SDValue();
13834+
} else if (LaneMask.getOperand(2) != Diff)
13835+
return SDValue();
13836+
13837+
auto StorePtr = Diff.getOperand(0);
13838+
auto ReadPtr = Diff.getOperand(1);
13839+
13840+
unsigned IntrinsicID = 0;
13841+
switch (EltSize) {
13842+
case 1:
13843+
IntrinsicID = Intrinsic::aarch64_sve_whilewr_b;
13844+
break;
13845+
case 2:
13846+
IntrinsicID = Intrinsic::aarch64_sve_whilewr_h;
13847+
break;
13848+
case 4:
13849+
IntrinsicID = Intrinsic::aarch64_sve_whilewr_s;
13850+
break;
13851+
case 8:
13852+
IntrinsicID = Intrinsic::aarch64_sve_whilewr_d;
13853+
break;
13854+
default:
13855+
return SDValue();
13856+
}
13857+
SDLoc DL(Op);
13858+
SDValue ID = DAG.getConstant(IntrinsicID, DL, MVT::i32);
13859+
auto N = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), ID,
13860+
StorePtr, ReadPtr);
13861+
return N;
13862+
}
13863+
1378513864
SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
1378613865
SelectionDAG &DAG) const {
13866+
13867+
if (SDValue SV = tryWhileWRFromOR(Op, DAG))
13868+
return SV;
1378713869
if (useSVEForFixedLengthVectorVT(Op.getValueType(),
1378813870
!Subtarget->isNeonAvailable()))
1378913871
return LowerToScalableOp(Op, DAG);

0 commit comments

Comments
 (0)