Skip to content

Commit 57f8f22

Browse files
committed
[Intrinsics][AArch64] Add intrinsic to mask off aliasing vector lanes
It can be unsafe to load a vector from an address and write a vector to an address if those two addresses have overlapping lanes within a vectorised loop iteration. This PR adds an intrinsic designed to create a mask with lanes disabled if they overlap between the two pointer arguments, so that only safe lanes are loaded, operated on and stored. Along with the two pointer parameters, the intrinsic also takes an immediate that represents the size in bytes of the vector element types, as well as an immediate i1 that is true if there is a write after-read-hazard or false if there is a read-after-write hazard. This will be used by llvm#100579 and replaces the existing lowering for whilewr since that isn't needed now we have the intrinsic.
1 parent 9cf2465 commit 57f8f22

File tree

10 files changed

+715
-11
lines changed

10 files changed

+715
-11
lines changed

llvm/docs/LangRef.rst

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23624,6 +23624,86 @@ Examples:
2362423624
%active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 %elem0, i64 429)
2362523625
%wide.masked.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %3, i32 4, <4 x i1> %active.lane.mask, <4 x i32> poison)
2362623626

23627+
.. _int_experimental_get_alias_lane_mask:
23628+
23629+
'``llvm.get.alias.lane.mask.*``' Intrinsics
23630+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
23631+
23632+
Syntax:
23633+
"""""""
23634+
This is an overloaded intrinsic.
23635+
23636+
::
23637+
23638+
declare <4 x i1> @llvm.experimental.get.alias.lane.mask.v4i1.i64(i64 %ptrA, i64 %ptrB, i32 immarg %elementSize, i1 immarg %writeAfterRead)
23639+
declare <8 x i1> @llvm.experimental.get.alias.lane.mask.v8i1.i64(i64 %ptrA, i64 %ptrB, i32 immarg %elementSize, i1 immarg %writeAfterRead)
23640+
declare <16 x i1> @llvm.experimental.get.alias.lane.mask.v16i1.i64(i64 %ptrA, i64 %ptrB, i32 immarg %elementSize, i1 immarg %writeAfterRead)
23641+
declare <vscale x 16 x i1> @llvm.experimental.get.alias.lane.mask.nxv16i1.i64(i64 %ptrA, i64 %ptrB, i32 immarg %elementSize, i1 immarg %writeAfterRead)
23642+
23643+
23644+
Overview:
23645+
"""""""""
23646+
23647+
Create a mask representing lanes that do or not overlap between two pointers across one vector loop iteration.
23648+
23649+
23650+
Arguments:
23651+
""""""""""
23652+
23653+
The first two arguments have the same scalar integer type.
23654+
The final two are immediates and the result is a vector with the i1 element type.
23655+
23656+
Semantics:
23657+
""""""""""
23658+
23659+
In the case that ``%writeAfterRead`` is true, the '``llvm.experimental.get.alias.lane.mask.*``' intrinsics are semantically equivalent
23660+
to:
23661+
23662+
::
23663+
23664+
%diff = (%ptrB - %ptrA) / %elementSize
23665+
%m[i] = (icmp ult i, %diff) || (%diff <= 0)
23666+
23667+
Otherwise they are semantically equivalent to:
23668+
23669+
::
23670+
23671+
%diff = abs(%ptrB - %ptrA) / %elementSize
23672+
%m[i] = (icmp ult i, %diff) || (%diff == 0)
23673+
23674+
where ``%m`` is a vector (mask) of active/inactive lanes with its elements
23675+
indexed by ``i``, and ``%ptrA``, ``%ptrB`` are the two i64 arguments to
23676+
``llvm.experimental.get.alias.lane.mask.*``, ``%elementSize`` is the i32 argument, ``%abs`` is the absolute difference operation, ``%icmp`` is an integer compare and ``ult``
23677+
the unsigned less-than comparison operator. The subtraction between ``%ptrA`` and ``%ptrB`` could be negative. The ``%writeAfterRead`` argument is expected to be true if the ``%ptrB`` is stored to after ``%ptrA`` is read from.
23678+
The above is equivalent to:
23679+
23680+
::
23681+
23682+
%m = @llvm.experimental.get.alias.lane.mask(%ptrA, %ptrB, %elementSize, %writeAfterRead)
23683+
23684+
This can, for example, be emitted by the loop vectorizer in which case
23685+
``%ptrA`` is a pointer that is read from within the loop, and ``%ptrB`` is a pointer that is stored to within the loop.
23686+
If the difference between these pointers is less than the vector factor, then they overlap (alias) within a loop iteration.
23687+
An example is if ``%ptrA`` is 20 and ``%ptrB`` is 23 with a vector factor of 8, then lanes 3, 4, 5, 6 and 7 of the vector loaded from ``%ptrA``
23688+
share addresses with lanes 0, 1, 2, 3, 4 and 5 from the vector stored to at ``%ptrB``.
23689+
An alias mask of these two pointers should be <1, 1, 1, 0, 0, 0, 0, 0> so that only the non-overlapping lanes are loaded and stored.
23690+
This operation allows many loops to be vectorised when it would otherwise be unsafe to do so.
23691+
23692+
To account for the fact that only a subset of lanes have been operated on in an iteration,
23693+
the loop's induction variable should be incremented by the popcount of the mask rather than the vector factor.
23694+
23695+
This mask ``%m`` can e.g. be used in masked load/store instructions.
23696+
23697+
23698+
Examples:
23699+
"""""""""
23700+
23701+
.. code-block:: llvm
23702+
23703+
%alias.lane.mask = call <4 x i1> @llvm.experimental.get.alias.lane.mask.v4i1.i64(i64 %ptrA, i64 %ptrB, i32 4, i1 1)
23704+
%vecA = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptrA, i32 4, <4 x i1> %alias.lane.mask, <4 x i32> poison)
23705+
[...]
23706+
call @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %vecA, <4 x i32>* %ptrB, i32 4, <4 x i1> %alias.lane.mask)
2362723707

2362823708
.. _int_experimental_vp_splice:
2362923709

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,11 @@ class TargetLoweringBase {
468468
return true;
469469
}
470470

471+
/// Return true if the @llvm.experimental.get.alias.lane.mask intrinsic should be expanded using generic code in SelectionDAGBuilder.
472+
virtual bool shouldExpandGetAliasLaneMask(EVT VT, EVT PtrVT, unsigned EltSize) const {
473+
return true;
474+
}
475+
471476
virtual bool shouldExpandGetVectorLength(EVT CountVT, unsigned VF,
472477
bool IsScalable) const {
473478
return true;

llvm/include/llvm/IR/Intrinsics.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,6 +2363,11 @@ let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn, ImmArg<ArgIndex<1>>
23632363
llvm_i32_ty]>;
23642364
}
23652365

2366+
def int_experimental_get_alias_lane_mask:
2367+
DefaultAttrsIntrinsic<[llvm_anyvector_ty],
2368+
[llvm_anyint_ty, LLVMMatchType<1>, llvm_anyint_ty, llvm_i1_ty],
2369+
[IntrNoMem, IntrNoSync, IntrWillReturn, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>]>;
2370+
23662371
def int_get_active_lane_mask:
23672372
DefaultAttrsIntrinsic<[llvm_anyvector_ty],
23682373
[llvm_anyint_ty, LLVMMatchType<1>],

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8276,6 +8276,50 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
82768276
visitVectorExtractLastActive(I, Intrinsic);
82778277
return;
82788278
}
8279+
case Intrinsic::experimental_get_alias_lane_mask: {
8280+
SDValue SourceValue = getValue(I.getOperand(0));
8281+
SDValue SinkValue = getValue(I.getOperand(1));
8282+
SDValue EltSize = getValue(I.getOperand(2));
8283+
bool IsWriteAfterRead = cast<ConstantSDNode>(getValue(I.getOperand(3)))->getZExtValue() != 0;
8284+
auto IntrinsicVT = EVT::getEVT(I.getType());
8285+
auto PtrVT = SourceValue->getValueType(0);
8286+
8287+
if (!TLI.shouldExpandGetAliasLaneMask(IntrinsicVT, PtrVT, cast<ConstantSDNode>(EltSize)->getSExtValue())) {
8288+
visitTargetIntrinsic(I, Intrinsic);
8289+
return;
8290+
}
8291+
8292+
SDValue Diff = DAG.getNode(ISD::SUB, sdl,
8293+
PtrVT, SinkValue, SourceValue);
8294+
if (!IsWriteAfterRead)
8295+
Diff = DAG.getNode(ISD::ABS, sdl, PtrVT, Diff);
8296+
8297+
Diff = DAG.getNode(ISD::SDIV, sdl, PtrVT, Diff, EltSize);
8298+
SDValue Zero = DAG.getTargetConstant(0, sdl, PtrVT);
8299+
8300+
// If the difference is positive then some elements may alias
8301+
auto CmpVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
8302+
PtrVT);
8303+
SDValue Cmp = DAG.getSetCC(sdl, CmpVT, Diff, Zero, IsWriteAfterRead ? ISD::SETLE : ISD::SETEQ);
8304+
8305+
// Splat the compare result then OR it with a lane mask
8306+
SDValue Splat = DAG.getSplat(IntrinsicVT, sdl, Cmp);
8307+
8308+
SDValue DiffMask;
8309+
// Don't emit an active lane mask if the target doesn't support it
8310+
if (TLI.shouldExpandGetActiveLaneMask(IntrinsicVT, PtrVT)) {
8311+
EVT VecTy = EVT::getVectorVT(*DAG.getContext(), PtrVT,
8312+
IntrinsicVT.getVectorElementCount());
8313+
SDValue DiffSplat = DAG.getSplat(VecTy, sdl, Diff);
8314+
SDValue VectorStep = DAG.getStepVector(sdl, VecTy);
8315+
DiffMask = DAG.getSetCC(sdl, IntrinsicVT, VectorStep,
8316+
DiffSplat, ISD::CondCode::SETULT);
8317+
} else {
8318+
DiffMask = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, sdl, IntrinsicVT, DAG.getTargetConstant(Intrinsic::get_active_lane_mask, sdl, MVT::i64), Zero, Diff);
8319+
}
8320+
SDValue Or = DAG.getNode(ISD::OR, sdl, IntrinsicVT, DiffMask, Splat);
8321+
setValue(&I, Or);
8322+
}
82798323
}
82808324
}
82818325

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2038,6 +2038,24 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
20382038
return false;
20392039
}
20402040

2041+
bool AArch64TargetLowering::shouldExpandGetAliasLaneMask(EVT VT, EVT PtrVT, unsigned EltSize) const {
2042+
if (!Subtarget->hasSVE2())
2043+
return true;
2044+
2045+
if (PtrVT != MVT::i64)
2046+
return true;
2047+
2048+
if (VT == MVT::v2i1 || VT == MVT::nxv2i1)
2049+
return EltSize != 8;
2050+
if( VT == MVT::v4i1 || VT == MVT::nxv4i1)
2051+
return EltSize != 4;
2052+
if (VT == MVT::v8i1 || VT == MVT::nxv8i1)
2053+
return EltSize != 2;
2054+
if (VT == MVT::v16i1 || VT == MVT::nxv16i1)
2055+
return EltSize != 1;
2056+
return true;
2057+
}
2058+
20412059
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
20422060
const IntrinsicInst *I) const {
20432061
if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
@@ -2818,6 +2836,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
28182836
MAKE_CASE(AArch64ISD::LS64_BUILD)
28192837
MAKE_CASE(AArch64ISD::LS64_EXTRACT)
28202838
MAKE_CASE(AArch64ISD::TBL)
2839+
MAKE_CASE(AArch64ISD::WHILEWR)
2840+
MAKE_CASE(AArch64ISD::WHILERW)
28212841
MAKE_CASE(AArch64ISD::FADD_PRED)
28222842
MAKE_CASE(AArch64ISD::FADDA_PRED)
28232843
MAKE_CASE(AArch64ISD::FADDV_PRED)
@@ -6016,6 +6036,16 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
60166036
EVT PtrVT = getPointerTy(DAG.getDataLayout());
60176037
return DAG.getNode(AArch64ISD::THREAD_POINTER, dl, PtrVT);
60186038
}
6039+
case Intrinsic::aarch64_sve_whilewr_b:
6040+
case Intrinsic::aarch64_sve_whilewr_h:
6041+
case Intrinsic::aarch64_sve_whilewr_s:
6042+
case Intrinsic::aarch64_sve_whilewr_d:
6043+
return DAG.getNode(AArch64ISD::WHILEWR, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2));
6044+
case Intrinsic::aarch64_sve_whilerw_b:
6045+
case Intrinsic::aarch64_sve_whilerw_h:
6046+
case Intrinsic::aarch64_sve_whilerw_s:
6047+
case Intrinsic::aarch64_sve_whilerw_d:
6048+
return DAG.getNode(AArch64ISD::WHILERW, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2));
60196049
case Intrinsic::aarch64_neon_abs: {
60206050
EVT Ty = Op.getValueType();
60216051
if (Ty == MVT::i64) {
@@ -6475,16 +6505,39 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
64756505
return DAG.getNode(AArch64ISD::USDOT, dl, Op.getValueType(),
64766506
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
64776507
}
6508+
case Intrinsic::experimental_get_alias_lane_mask:
64786509
case Intrinsic::get_active_lane_mask: {
6510+
unsigned IntrinsicID = Intrinsic::aarch64_sve_whilelo;
6511+
if (IntNo == Intrinsic::experimental_get_alias_lane_mask) {
6512+
uint64_t EltSize = Op.getOperand(3)->getAsZExtVal();
6513+
bool IsWriteAfterRead = Op.getOperand(4)->getAsZExtVal() == 1;
6514+
switch (EltSize) {
6515+
case 1:
6516+
IntrinsicID = IsWriteAfterRead ? Intrinsic::aarch64_sve_whilewr_b : Intrinsic::aarch64_sve_whilerw_b;
6517+
break;
6518+
case 2:
6519+
IntrinsicID = IsWriteAfterRead ? Intrinsic::aarch64_sve_whilewr_h : Intrinsic::aarch64_sve_whilerw_h;
6520+
break;
6521+
case 4:
6522+
IntrinsicID = IsWriteAfterRead ? Intrinsic::aarch64_sve_whilewr_s : Intrinsic::aarch64_sve_whilerw_s;
6523+
break;
6524+
case 8:
6525+
IntrinsicID = IsWriteAfterRead ? Intrinsic::aarch64_sve_whilewr_d : Intrinsic::aarch64_sve_whilerw_d;
6526+
break;
6527+
default:
6528+
llvm_unreachable("Unexpected element size for get.alias.lane.mask");
6529+
break;
6530+
}
6531+
}
64796532
SDValue ID =
6480-
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
6533+
DAG.getTargetConstant(IntrinsicID, dl, MVT::i64);
64816534

64826535
EVT VT = Op.getValueType();
64836536
if (VT.isScalableVector())
64846537
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Op.getOperand(1),
64856538
Op.getOperand(2));
64866539

6487-
// We can use the SVE whilelo instruction to lower this intrinsic by
6540+
// We can use the SVE whilelo/whilewr/whilerw instruction to lower this intrinsic by
64886541
// creating the appropriate sequence of scalable vector operations and
64896542
// then extracting a fixed-width subvector from the scalable vector.
64906543

@@ -19872,7 +19925,9 @@ static bool isPredicateCCSettingOp(SDValue N) {
1987219925
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
1987319926
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
1987419927
// get_active_lane_mask is lowered to a whilelo instruction.
19875-
N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
19928+
N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask ||
19929+
// get_alias_lane_mask is lowered to a whilewr/rw instruction.
19930+
N.getConstantOperandVal(0) == Intrinsic::experimental_get_alias_lane_mask)))
1987619931
return true;
1987719932

1987819933
return false;
@@ -27626,6 +27681,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
2762627681
return;
2762727682
}
2762827683
case Intrinsic::experimental_vector_match:
27684+
case Intrinsic::experimental_get_alias_lane_mask:
2762927685
case Intrinsic::get_active_lane_mask: {
2763027686
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
2763127687
return;

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ enum NodeType : unsigned {
298298
SMAXV,
299299
UMAXV,
300300

301+
// Alias lane masks
302+
WHILEWR,
303+
WHILERW,
304+
301305
SADDV_PRED,
302306
UADDV_PRED,
303307
SMAXV_PRED,
@@ -993,6 +997,8 @@ class AArch64TargetLowering : public TargetLowering {
993997

994998
bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
995999

1000+
bool shouldExpandGetAliasLaneMask(EVT VT, EVT PtrVT, unsigned EltSize) const override;
1001+
9961002
bool
9971003
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
9981004

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ def AArch64st1q_scatter : SDNode<"AArch64ISD::SST1Q_PRED", SDT_AArch64_SCATTER_V
140140
// AArch64 SVE/SVE2 - the remaining node definitions
141141
//
142142

143+
// Alias masks
144+
def SDT_AArch64Mask : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<2, 1>, SDTCVecEltisVT<0,i1>]>;
145+
def AArch64whilewr : SDNode<"AArch64ISD::WHILEWR", SDT_AArch64Mask>;
146+
def AArch64whilerw : SDNode<"AArch64ISD::WHILERW", SDT_AArch64Mask>;
147+
143148
// SVE CNT/INC/RDVL
144149
def sve_rdvl_imm : ComplexPattern<i64, 1, "SelectRDVLImm<-32, 31, 16>">;
145150
def sve_cnth_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 16, 8>">;
@@ -3914,9 +3919,9 @@ let Predicates = [HasSVE2_or_SME] in {
39143919
defm WHILEHI_PXX : sve_int_while8_rr<0b101, "whilehi", int_aarch64_sve_whilehi, int_aarch64_sve_whilelo>;
39153920

39163921
// SVE2 pointer conflict compare
3917-
defm WHILEWR_PXX : sve2_int_while_rr<0b0, "whilewr", "int_aarch64_sve_whilewr">;
3918-
defm WHILERW_PXX : sve2_int_while_rr<0b1, "whilerw", "int_aarch64_sve_whilerw">;
3919-
} // End HasSVE2_or_SME
3922+
defm WHILEWR_PXX : sve2_int_while_rr<0b0, "whilewr", AArch64whilewr>;
3923+
defm WHILERW_PXX : sve2_int_while_rr<0b1, "whilerw", AArch64whilerw>;
3924+
} // End HasSVE2orSME
39203925

39213926
let Predicates = [HasSVEAES, HasNonStreamingSVE2_or_SSVE_AES] in {
39223927
// SVE2 crypto destructive binary operations

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5835,16 +5835,16 @@ class sve2_int_while_rr<bits<2> sz8_64, bits<1> rw, string asm,
58355835
let isWhile = 1;
58365836
}
58375837

5838-
multiclass sve2_int_while_rr<bits<1> rw, string asm, string op> {
5838+
multiclass sve2_int_while_rr<bits<1> rw, string asm, SDPatternOperator op> {
58395839
def _B : sve2_int_while_rr<0b00, rw, asm, PPR8>;
58405840
def _H : sve2_int_while_rr<0b01, rw, asm, PPR16>;
58415841
def _S : sve2_int_while_rr<0b10, rw, asm, PPR32>;
58425842
def _D : sve2_int_while_rr<0b11, rw, asm, PPR64>;
58435843

5844-
def : SVE_2_Op_Pat<nxv16i1, !cast<SDPatternOperator>(op # _b), i64, i64, !cast<Instruction>(NAME # _B)>;
5845-
def : SVE_2_Op_Pat<nxv8i1, !cast<SDPatternOperator>(op # _h), i64, i64, !cast<Instruction>(NAME # _H)>;
5846-
def : SVE_2_Op_Pat<nxv4i1, !cast<SDPatternOperator>(op # _s), i64, i64, !cast<Instruction>(NAME # _S)>;
5847-
def : SVE_2_Op_Pat<nxv2i1, !cast<SDPatternOperator>(op # _d), i64, i64, !cast<Instruction>(NAME # _D)>;
5844+
def : SVE_2_Op_Pat<nxv16i1, op, i64, i64, !cast<Instruction>(NAME # _B)>;
5845+
def : SVE_2_Op_Pat<nxv8i1, op, i64, i64, !cast<Instruction>(NAME # _H)>;
5846+
def : SVE_2_Op_Pat<nxv4i1, op, i64, i64, !cast<Instruction>(NAME # _S)>;
5847+
def : SVE_2_Op_Pat<nxv2i1, op, i64, i64, !cast<Instruction>(NAME # _D)>;
58485848
}
58495849

58505850
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)