
Commit 927b8a0

[AArch64][GlobalISel] Combine vecreduce(ext) to {U/S}ADDLV (#75832)
1 parent ba131b7 commit 927b8a0
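
For context (not part of the commit message): the combine targets additive reductions of widened vectors. A minimal C++ sketch of the source-level shape, assuming the loop is auto-vectorized; the function name and shape are illustrative, not taken from this commit:

    #include <cstdint>

    // Sums 16 bytes into a 32-bit accumulator. Vectorizers typically lower
    // this to zext <16 x i8> followed by vecreduce.add, i.e. the
    // G_VECREDUCE_ADD(G_ZEXT(...)) shape matched by this combine; after the
    // rewrite it can select to a single across-lane reduction (a
    // uaddlv-style instruction) instead of widening every lane first. A
    // signed accumulation maps to the new G_SADDLV the same way.
    uint32_t sum_u8x16(const uint8_t *p) {
      uint32_t s = 0;
      for (int i = 0; i < 16; ++i)
        s += p[i]; // each byte is zero-extended before the add
      return s;
    }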

7 files changed: +1410 −889 lines

llvm/lib/Target/AArch64/AArch64Combine.td

Lines changed: 10 additions & 1 deletion
@@ -44,13 +44,22 @@ def ext_addv_to_udot_addv : GICombineRule<
 >;
 }
 
+def ext_uaddv_to_uaddlv_matchinfo : GIDefMatchData<"std::pair<Register, bool>">;
+def ext_uaddv_to_uaddlv : GICombineRule<
+  (defs root:$root, ext_uaddv_to_uaddlv_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_VECREDUCE_ADD):$root,
+         [{ return matchExtUaddvToUaddlv(*${root}, MRI, ${matchinfo}); }]),
+  (apply [{ applyExtUaddvToUaddlv(*${root}, MRI, B, Observer, ${matchinfo}); }])
+>;
+
 def AArch64PreLegalizerCombiner: GICombiner<
   "AArch64PreLegalizerCombinerImpl", [all_combines,
                                       fconstant_to_constant,
                                       icmp_redundant_trunc,
                                       fold_global_offset,
                                       shuffle_to_extract,
-                                      ext_addv_to_udot_addv]> {
+                                      ext_addv_to_udot_addv,
+                                      ext_uaddv_to_uaddlv]> {
   let CombineAllMethodName = "tryCombineAllImpl";
 }
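
The rule's match data is a plain std::pair. A minimal sketch of the contract between the match and apply snippets above; the field meanings are spelled out here as an assumption, inferred from the combiner code added later in this commit:

    #include "llvm/CodeGen/Register.h"
    #include <utility>

    // first  = the un-extended vector source feeding the G_ZEXT/G_SEXT
    // second = true for the signed form (G_SADDLV), false for G_UADDLV
    using ExtUaddvToUaddlvMatchInfo = std::pair<llvm::Register, bool>;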

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 1 addition & 0 deletions
@@ -2464,6 +2464,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SADDV)
     MAKE_CASE(AArch64ISD::UADDV)
     MAKE_CASE(AArch64ISD::UADDLV)
+    MAKE_CASE(AArch64ISD::SADDLV)
     MAKE_CASE(AArch64ISD::SDOT)
     MAKE_CASE(AArch64ISD::UDOT)
     MAKE_CASE(AArch64ISD::SMINV)

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
@@ -248,6 +248,7 @@ enum NodeType : unsigned {
 
   // Unsigned sum Long across Vector
   UADDLV,
+  SADDLV,
 
   // Add Pairwise of two vectors
   ADDP,

llvm/lib/Target/AArch64/AArch64InstrGISel.td

Lines changed: 15 additions & 0 deletions
@@ -227,6 +227,18 @@ def G_SMULL : AArch64GenericInstruction {
   let hasSideEffects = 0;
 }
 
+def G_UADDLV : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1);
+  let hasSideEffects = 0;
+}
+
+def G_SADDLV : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1);
+  let hasSideEffects = 0;
+}
+
 def G_UDOT : AArch64GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type0:$src1, type0:$src2, type0:$src3);
@@ -282,6 +294,9 @@ def : GINodeEquiv<G_BSP, AArch64bsp>;
 def : GINodeEquiv<G_UMULL, AArch64umull>;
 def : GINodeEquiv<G_SMULL, AArch64smull>;
 
+def : GINodeEquiv<G_SADDLV, AArch64saddlv>;
+def : GINodeEquiv<G_UADDLV, AArch64uaddlv>;
+
 def : GINodeEquiv<G_UDOT, AArch64udot>;
 def : GINodeEquiv<G_SDOT, AArch64sdot>;

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 17 additions & 8 deletions
@@ -799,6 +799,7 @@ def AArch64uminv : SDNode<"AArch64ISD::UMINV", SDT_AArch64UnaryVec>;
 def AArch64smaxv : SDNode<"AArch64ISD::SMAXV", SDT_AArch64UnaryVec>;
 def AArch64umaxv : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>;
 def AArch64uaddlv : SDNode<"AArch64ISD::UADDLV", SDT_AArch64uaddlp>;
+def AArch64saddlv : SDNode<"AArch64ISD::SADDLV", SDT_AArch64uaddlp>;
 
 def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs),
                            [(abdu node:$lhs, node:$rhs),
@@ -6680,17 +6681,25 @@ def : Pat<(v4i32 (AArch64uaddlv (v8i16 (AArch64uaddlp (v16i8 V128:$op))))),
 def : Pat<(v4i32 (AArch64uaddlv (v4i16 (AArch64uaddlp (v8i8 V64:$op))))),
           (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv8i8v V64:$op), hsub))>;
 
-def : Pat<(v4i32 (AArch64uaddlv (v8i8 V64:$Rn))),
-          (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv8i8v V64:$Rn), hsub))>;
+multiclass SIMDAcrossLaneLongReductionIntrinsic<string Opc, SDPatternOperator addlv> {
+  def : Pat<(v4i32 (addlv (v8i8 V64:$Rn))),
+            (v4i32 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v8i8v") V64:$Rn), hsub))>;
 
-def : Pat<(v4i32 (AArch64uaddlv (v4i16 V64:$Rn))),
-          (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv4i16v V64:$Rn), ssub))>;
+  def : Pat<(v4i32 (addlv (v4i16 V64:$Rn))),
+            (v4i32 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v4i16v") V64:$Rn), ssub))>;
 
-def : Pat<(v4i32 (AArch64uaddlv (v16i8 V128:$Rn))),
-          (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv16i8v V128:$Rn), hsub))>;
+  def : Pat<(v4i32 (addlv (v16i8 V128:$Rn))),
+            (v4i32 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v16i8v") V128:$Rn), hsub))>;
 
-def : Pat<(v4i32 (AArch64uaddlv (v8i16 V128:$Rn))),
-          (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv8i16v V128:$Rn), ssub))>;
+  def : Pat<(v4i32 (addlv (v8i16 V128:$Rn))),
+            (v4i32 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v8i16v") V128:$Rn), ssub))>;
+
+  def : Pat<(v2i64 (addlv (v4i32 V128:$Rn))),
+            (v2i64 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v4i32v") V128:$Rn), dsub))>;
+}
+
+defm : SIMDAcrossLaneLongReductionIntrinsic<"UADDLV", AArch64uaddlv>;
+defm : SIMDAcrossLaneLongReductionIntrinsic<"SADDLV", AArch64saddlv>;
 
 // Patterns for across-vector intrinsics, that have a node equivalent, that
 // returns a vector (with only the low lane defined) instead of a scalar.
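
For readability, a hypothetical C++ table restating what the multiclass above expands to: each source vector type selects one concrete across-lane instruction (via the Opc#"..." string concatenation) and the subregister of the full vector register that holds the widened scalar sum. The entries mirror the Pats above; the SADDLV set is assumed symmetric to UADDLV, as the two defm lines imply:

    struct AddlvPattern {
      const char *SrcTy;      // input vector type in the Pat
      const char *InstSuffix; // suffix fed to !cast<Instruction>(Opc # ...)
      const char *SubReg;     // subreg holding the scalar result
    };
    static const AddlvPattern Coverage[] = {
        {"v8i8",  "v8i8v",  "hsub"}, // i16 sum, placed in a v4i32 result
        {"v4i16", "v4i16v", "ssub"}, // i32 sum, placed in a v4i32 result
        {"v16i8", "v16i8v", "hsub"}, // i16 sum, placed in a v4i32 result
        {"v8i16", "v8i16v", "ssub"}, // i32 sum, placed in a v4i32 result
        {"v4i32", "v4i32v", "dsub"}, // i64 sum, placed in a v2i64 result (new)
    };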

llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp

Lines changed: 144 additions & 0 deletions
@@ -410,6 +410,150 @@ void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
+// Ensure that the type coming from the extend instruction is the right size
+bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  // Check if the reduction operand is defined by an extend
+  MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  auto ExtOpc = ExtMI->getOpcode();
+
+  if (ExtOpc == TargetOpcode::G_ZEXT)
+    std::get<1>(MatchInfo) = 0;
+  else if (ExtOpc == TargetOpcode::G_SEXT)
+    std::get<1>(MatchInfo) = 1;
+  else
+    return false;
+
+  // Check if the source register is a valid type
+  Register ExtSrcReg = ExtMI->getOperand(1).getReg();
+  LLT ExtSrcTy = MRI.getType(ExtSrcReg);
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  if ((DstTy.getScalarSizeInBits() == 16 &&
+       ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) ||
+      (DstTy.getScalarSizeInBits() == 32 &&
+       ExtSrcTy.getNumElements() % 4 == 0) ||
+      (DstTy.getScalarSizeInBits() == 64 &&
+       ExtSrcTy.getNumElements() % 4 == 0)) {
+    std::get<0>(MatchInfo) = ExtSrcReg;
+    return true;
+  }
+  return false;
+}
+
+void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           MachineIRBuilder &B, GISelChangeObserver &Observer,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
+  Register SrcReg = std::get<0>(MatchInfo);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  LLT DstTy = MRI.getType(DstReg);
+
+  // If SrcTy has more elements than expected, split them into multiple
+  // instructions and sum the results
+  LLT MainTy;
+  SmallVector<Register, 1> WorkingRegisters;
+  unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
+  unsigned SrcNumElem = SrcTy.getNumElements();
+  if ((SrcScalSize == 8 && SrcNumElem > 16) ||
+      (SrcScalSize == 16 && SrcNumElem > 8) ||
+      (SrcScalSize == 32 && SrcNumElem > 4)) {
+
+    LLT LeftoverTy;
+    SmallVector<Register, 4> LeftoverRegs;
+    if (SrcScalSize == 8)
+      MainTy = LLT::fixed_vector(16, 8);
+    else if (SrcScalSize == 16)
+      MainTy = LLT::fixed_vector(8, 16);
+    else if (SrcScalSize == 32)
+      MainTy = LLT::fixed_vector(4, 32);
+    else
+      llvm_unreachable("Source's Scalar Size not supported");
+
+    // Extract the parts, so each extracted source can be put through
+    // U/SADDLV, and collect the values in a small vector
+    extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
+                 LeftoverRegs, B, MRI);
+    for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
+      WorkingRegisters.push_back(LeftoverRegs[I]);
+    }
+  } else {
+    WorkingRegisters.push_back(SrcReg);
+    MainTy = SrcTy;
+  }
+
+  unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
+  LLT MidScalarLLT = LLT::scalar(MidScalarSize);
+  Register zeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0);
+  for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
+    // If the number of elements is too small to build an instruction, extend
+    // its size before applying addlv
+    LLT WorkingRegTy = MRI.getType(WorkingRegisters[I]);
+    if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
+        (WorkingRegTy.getNumElements() == 4)) {
+      WorkingRegisters[I] =
+          B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                              : TargetOpcode::G_ZEXT,
+                       {LLT::fixed_vector(4, 16)}, {WorkingRegisters[I]})
+              .getReg(0);
+    }
+
+    // Generate the {U/S}ADDLV instruction, whose output is always double the
+    // source's scalar size
+    LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
+                                      : LLT::fixed_vector(2, 64);
+    Register addlvReg =
+        B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]}).getReg(0);
+
+    // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32
+    // or v2i64 register:
+    //   i16 and i32 results use a v4i32 register
+    //   i64 results use a v2i64 register
+    // Therefore we have to extract/truncate the value to the right type
+    if (MidScalarSize == 32 || MidScalarSize == 64) {
+      WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {MidScalarLLT}, {addlvReg, zeroReg})
+                                .getReg(0);
+    } else {
+      Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {LLT::scalar(32)}, {addlvReg, zeroReg})
+                                .getReg(0);
+      WorkingRegisters[I] =
+          B.buildTrunc({MidScalarLLT}, {extractReg}).getReg(0);
+    }
+  }
+
+  Register outReg;
+  if (WorkingRegisters.size() > 1) {
+    outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
+                 .getReg(0);
+    for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
+      outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I]).getReg(0);
+    }
+  } else {
+    outReg = WorkingRegisters[0];
+  }
+
+  if (DstTy.getScalarSizeInBits() > MidScalarSize) {
+    // Extend the scalar value if DstTy's scalar size is more than double the
+    // source's scalar size
+    B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                        : TargetOpcode::G_ZEXT,
+                 {DstReg}, {outReg});
+  } else {
+    B.buildCopy(DstReg, outReg);
+  }
+
+  MI.eraseFromParent();
+}
+
 bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                         CombinerHelper &Helper, GISelChangeObserver &Observer) {
   // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if
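
The splitting path in applyExtUaddvToUaddlv relies on the sum of the widened lanes being equal to the sum of per-part across-lane reductions. A scalar model of that equivalence, as a self-contained sanity check (ordinary C++, not LLVM code; the helper mimics UADDLV's widen-then-reduce semantics):

    #include <array>
    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <numeric>

    // Models UADDLV on bytes: zero-extend each lane, then sum across lanes.
    static uint32_t uaddlv_u8(const uint8_t *Lanes, size_t N) {
      uint32_t Sum = 0;
      for (size_t I = 0; I != N; ++I)
        Sum += static_cast<uint32_t>(Lanes[I]);
      return Sum;
    }

    int main() {
      std::array<uint8_t, 32> V{};
      std::iota(V.begin(), V.end(), uint8_t{200}); // sums overflow a byte
      // vecreduce_add(zext(<32 x i8>)) computed in one go...
      uint32_t Whole = uaddlv_u8(V.data(), 32);
      // ...equals the split form the combine emits for oversized sources:
      // one UADDLV per <16 x i8> part, then a scalar add of the partials.
      uint32_t Split = uaddlv_u8(V.data(), 16) + uaddlv_u8(V.data() + 16, 16);
      assert(Whole == Split);
      return 0;
    }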
