Skip to content

Commit 8fa824d

Browse files
committed
[ARM] Add predicated add reduction patterns
Given a vecreduce.add(select(p, x, 0)), we can convert that to a predicated vaddv, as the else value for the select is the identity value, a zero. That is what this patch does for the vaddv, vaddva, vaddlv and vaddlva instructions, copying the existing patterns to also handle predication through a select. Differential Revision: https://reviews.llvm.org/D84101
1 parent 89e61e7 commit 8fa824d

File tree

5 files changed

+193
-1101
lines changed

5 files changed

+193
-1101
lines changed

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,6 +1718,8 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
17181718
case ARMISD::VMULLu: return "ARMISD::VMULLu";
17191719
case ARMISD::VADDVs: return "ARMISD::VADDVs";
17201720
case ARMISD::VADDVu: return "ARMISD::VADDVu";
1721+
case ARMISD::VADDVps: return "ARMISD::VADDVps";
1722+
case ARMISD::VADDVpu: return "ARMISD::VADDVpu";
17211723
case ARMISD::VADDLVs: return "ARMISD::VADDLVs";
17221724
case ARMISD::VADDLVu: return "ARMISD::VADDLVu";
17231725
case ARMISD::VADDLVAs: return "ARMISD::VADDLVAs";
@@ -14729,6 +14731,20 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
1472914731
return A;
1473014732
return SDValue();
1473114733
};
14734+
auto IsPredVADDV = [&](MVT RetTy, unsigned ExtendCode,
14735+
ArrayRef<MVT> ExtTypes, SDValue &Mask) {
14736+
if (ResVT != RetTy || N0->getOpcode() != ISD::VSELECT ||
14737+
!ISD::isBuildVectorAllZeros(N0->getOperand(2).getNode()))
14738+
return SDValue();
14739+
Mask = N0->getOperand(0);
14740+
SDValue Ext = N0->getOperand(1);
14741+
if (Ext->getOpcode() != ExtendCode)
14742+
return SDValue();
14743+
SDValue A = Ext->getOperand(0);
14744+
if (llvm::any_of(ExtTypes, [&A](MVT Ty) { return A.getValueType() == Ty; }))
14745+
return A;
14746+
return SDValue();
14747+
};
1473214748
auto IsVMLAV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes,
1473314749
SDValue &A, SDValue &B) {
1473414750
if (ResVT != RetTy || N0->getOpcode() != ISD::MUL)
@@ -14759,6 +14775,16 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
1475914775
if (SDValue A = IsVADDV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v4i32}))
1476014776
return Create64bitNode(ARMISD::VADDLVu, {A});
1476114777

14778+
SDValue Mask;
14779+
if (SDValue A = IsPredVADDV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, Mask))
14780+
return DAG.getNode(ARMISD::VADDVps, dl, ResVT, A, Mask);
14781+
if (SDValue A = IsPredVADDV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}, Mask))
14782+
return DAG.getNode(ARMISD::VADDVpu, dl, ResVT, A, Mask);
14783+
if (SDValue A = IsPredVADDV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v4i32}, Mask))
14784+
return Create64bitNode(ARMISD::VADDLVps, {A, Mask});
14785+
if (SDValue A = IsPredVADDV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v4i32}, Mask))
14786+
return Create64bitNode(ARMISD::VADDLVpu, {A, Mask});
14787+
1476214788
SDValue A, B;
1476314789
if (IsVMLAV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B))
1476414790
return DAG.getNode(ARMISD::VMLAVs, dl, ResVT, A, B);

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,16 @@ class VectorType;
219219
// MVE reductions
220220
VADDVs, // sign- or zero-extend the elements of a vector to i32,
221221
VADDVu, // add them all together, and return an i32 of their sum
222+
VADDVps, // Same as VADDV[su] but with a v4i1 predicate mask
223+
VADDVpu,
222224
VADDLVs, // sign- or zero-extend elements to i64 and sum, returning
223225
VADDLVu, // the low and high 32-bit halves of the sum
224-
VADDLVAs, // same as VADDLV[su] but also add an input accumulator
226+
VADDLVAs, // Same as VADDLV[su] but also add an input accumulator
225227
VADDLVAu, // provided as low and high halves
226-
VADDLVps, // same as VADDLVs but with a v4i1 predicate mask
227-
VADDLVpu, // same as VADDLVu but with a v4i1 predicate mask
228-
VADDLVAps, // same as VADDLVps but with a v4i1 predicate mask
229-
VADDLVApu, // same as VADDLVpu but with a v4i1 predicate mask
228+
VADDLVps, // Same as VADDLV[su] but with a v4i1 predicate mask
229+
VADDLVpu,
230+
VADDLVAps, // Same as VADDLVp[su] but with a v4i1 predicate mask
231+
VADDLVApu,
230232
VMLAVs,
231233
VMLAVu,
232234
VMLALVs,

llvm/lib/Target/ARM/ARMInstrMVE.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,8 +684,13 @@ class MVE_VADDV<string iname, string suffix, dag iops, string cstr,
684684
let validForTailPredication = 1;
685685
}
686686

687+
def SDTVecReduceP : SDTypeProfile<1, 2, [ // VADDLVp
688+
SDTCisInt<0>, SDTCisVec<1>, SDTCisVec<2>
689+
]>;
687690
def ARMVADDVs : SDNode<"ARMISD::VADDVs", SDTVecReduce>;
688691
def ARMVADDVu : SDNode<"ARMISD::VADDVu", SDTVecReduce>;
692+
def ARMVADDVps : SDNode<"ARMISD::VADDVps", SDTVecReduceP>;
693+
def ARMVADDVpu : SDNode<"ARMISD::VADDVpu", SDTVecReduceP>;
689694

690695
multiclass MVE_VADDV_A<MVEVectorVTInfo VTI> {
691696
def acc : MVE_VADDV<"vaddva", VTI.Suffix,
@@ -702,20 +707,39 @@ multiclass MVE_VADDV_A<MVEVectorVTInfo VTI> {
702707
if VTI.Unsigned then {
703708
def : Pat<(i32 (vecreduce_add (VTI.Vec MQPR:$vec))),
704709
(i32 (InstN $vec))>;
710+
def : Pat<(i32 (vecreduce_add (VTI.Vec (vselect (VTI.Pred VCCR:$pred),
711+
(VTI.Vec MQPR:$vec),
712+
(VTI.Vec ARMimmAllZerosV))))),
713+
(i32 (InstN $vec, ARMVCCThen, $pred))>;
705714
def : Pat<(i32 (ARMVADDVu (VTI.Vec MQPR:$vec))),
706715
(i32 (InstN $vec))>;
716+
def : Pat<(i32 (ARMVADDVpu (VTI.Vec MQPR:$vec), (VTI.Pred VCCR:$pred))),
717+
(i32 (InstN $vec, ARMVCCThen, $pred))>;
707718
def : Pat<(i32 (add (i32 (vecreduce_add (VTI.Vec MQPR:$vec))),
708719
(i32 tGPREven:$acc))),
709720
(i32 (InstA $acc, $vec))>;
721+
def : Pat<(i32 (add (i32 (vecreduce_add (VTI.Vec (vselect (VTI.Pred VCCR:$pred),
722+
(VTI.Vec MQPR:$vec),
723+
(VTI.Vec ARMimmAllZerosV))))),
724+
(i32 tGPREven:$acc))),
725+
(i32 (InstA $acc, $vec, ARMVCCThen, $pred))>;
710726
def : Pat<(i32 (add (i32 (ARMVADDVu (VTI.Vec MQPR:$vec))),
711727
(i32 tGPREven:$acc))),
712728
(i32 (InstA $acc, $vec))>;
729+
def : Pat<(i32 (add (i32 (ARMVADDVpu (VTI.Vec MQPR:$vec), (VTI.Pred VCCR:$pred))),
730+
(i32 tGPREven:$acc))),
731+
(i32 (InstA $acc, $vec, ARMVCCThen, $pred))>;
713732
} else {
714733
def : Pat<(i32 (ARMVADDVs (VTI.Vec MQPR:$vec))),
715734
(i32 (InstN $vec))>;
716735
def : Pat<(i32 (add (i32 (ARMVADDVs (VTI.Vec MQPR:$vec))),
717736
(i32 tGPREven:$acc))),
718737
(i32 (InstA $acc, $vec))>;
738+
def : Pat<(i32 (ARMVADDVps (VTI.Vec MQPR:$vec), (VTI.Pred VCCR:$pred))),
739+
(i32 (InstN $vec, ARMVCCThen, $pred))>;
740+
def : Pat<(i32 (add (i32 (ARMVADDVps (VTI.Vec MQPR:$vec), (VTI.Pred VCCR:$pred))),
741+
(i32 tGPREven:$acc))),
742+
(i32 (InstA $acc, $vec, ARMVCCThen, $pred))>;
719743
}
720744

721745
def : Pat<(i32 (int_arm_mve_addv_predicated (VTI.Vec MQPR:$vec),

0 commit comments

Comments
 (0)