
Commit 0ed325d (parent: 91ad317)

[AArch64][GlobalISel] Combine vecreduce(ext) to {U/S}ADDLV

Combines vecreduce_add(ext) into {U/S}ADDLV instructions.
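By way of illustration (not part of the commit), with a zero-extend feeding the reduction the new rule rewrites pre-legalizer generic MIR roughly as follows. The <8 x s16> input type and the register names are invented for the example; a sign-extend matches the same rule and emits G_SADDLV instead of G_UADDLV.

Before the combine:

    %ext:_(<8 x s32>) = G_ZEXT %vec(<8 x s16>)
    %sum:_(s32) = G_VECREDUCE_ADD %ext(<8 x s32>)

After the combine (the reduction result lives in lane 0 of the G_UADDLV output):

    %lv:_(<4 x s32>) = G_UADDLV %vec(<8 x s16>)
    %zero:_(s64) = G_CONSTANT i64 0
    %elt:_(s32) = G_EXTRACT_VECTOR_ELT %lv(<4 x s32>), %zero(s64)
    %sum:_(s32) = COPY %elt(s32)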

6 files changed: +1417, -889 lines

llvm/lib/Target/AArch64/AArch64Combine.td

Lines changed: 10 additions & 1 deletion
@@ -44,13 +44,22 @@ def ext_addv_to_udot_addv : GICombineRule<
 >;
 }
 
+def ext_uaddv_to_uaddlv_matchinfo : GIDefMatchData<"std::pair<Register, bool>">;
+def ext_uaddv_to_uaddlv : GICombineRule<
+  (defs root:$root, ext_uaddv_to_uaddlv_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_VECREDUCE_ADD):$root,
+         [{ return matchExtUaddvToUaddlv(*${root}, MRI, ${matchinfo}); }]),
+  (apply [{ applyExtUaddvToUaddlv(*${root}, MRI, B, Observer, ${matchinfo}); }])
+>;
+
 def AArch64PreLegalizerCombiner: GICombiner<
   "AArch64PreLegalizerCombinerImpl", [all_combines,
                                       fconstant_to_constant,
                                       icmp_redundant_trunc,
                                       fold_global_offset,
                                       shuffle_to_extract,
-                                      ext_addv_to_udot_addv]> {
+                                      ext_addv_to_udot_addv,
+                                      ext_uaddv_to_uaddlv]> {
   let CombineAllMethodName = "tryCombineAllImpl";
 }

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
@@ -248,6 +248,7 @@ enum NodeType : unsigned {
 
   // Unsigned sum Long across Vector
   UADDLV,
+  SADDLV,
 
   // Add Pairwise of two vectors
   ADDP,

llvm/lib/Target/AArch64/AArch64InstrGISel.td

Lines changed: 15 additions & 0 deletions
@@ -227,6 +227,18 @@ def G_SMULL : AArch64GenericInstruction {
   let hasSideEffects = 0;
 }
 
+def G_UADDLV : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1);
+  let hasSideEffects = 0;
+}
+
+def G_SADDLV : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1);
+  let hasSideEffects = 0;
+}
+
 def G_UDOT : AArch64GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type0:$src1, type0:$src2, type0:$src3);

@@ -282,6 +294,9 @@ def : GINodeEquiv<G_BSP, AArch64bsp>;
 def : GINodeEquiv<G_UMULL, AArch64umull>;
 def : GINodeEquiv<G_SMULL, AArch64smull>;
 
+def : GINodeEquiv<G_SADDLV, AArch64saddlv>;
+def : GINodeEquiv<G_UADDLV, AArch64uaddlv>;
+
 def : GINodeEquiv<G_UDOT, AArch64udot>;
 def : GINodeEquiv<G_SDOT, AArch64sdot>;

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 17 additions & 8 deletions
@@ -799,6 +799,7 @@ def AArch64uminv : SDNode<"AArch64ISD::UMINV", SDT_AArch64UnaryVec>;
 def AArch64smaxv : SDNode<"AArch64ISD::SMAXV", SDT_AArch64UnaryVec>;
 def AArch64umaxv : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>;
 def AArch64uaddlv : SDNode<"AArch64ISD::UADDLV", SDT_AArch64uaddlp>;
+def AArch64saddlv : SDNode<"AArch64ISD::SADDLV", SDT_AArch64uaddlp>;
 
 def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs),
                            [(abdu node:$lhs, node:$rhs),

@@ -6680,17 +6681,25 @@ def : Pat<(v4i32 (AArch64uaddlv (v8i16 (AArch64uaddlp (v16i8 V128:$op))))),
 def : Pat<(v4i32 (AArch64uaddlv (v4i16 (AArch64uaddlp (v8i8 V64:$op))))),
           (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv8i8v V64:$op), hsub))>;
 
-def : Pat<(v4i32 (AArch64uaddlv (v8i8 V64:$Rn))),
-          (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv8i8v V64:$Rn), hsub))>;
+multiclass SIMDAcrossLaneLongReductionIntrinsic<string Opc, SDPatternOperator addlv> {
+  def : Pat<(v4i32 (addlv (v8i8 V64:$Rn))),
+            (v4i32 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v8i8v") V64:$Rn), hsub))>;
 
-def : Pat<(v4i32 (AArch64uaddlv (v4i16 V64:$Rn))),
-          (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv4i16v V64:$Rn), ssub))>;
+  def : Pat<(v4i32 (addlv (v4i16 V64:$Rn))),
+            (v4i32 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v4i16v") V64:$Rn), ssub))>;
 
-def : Pat<(v4i32 (AArch64uaddlv (v16i8 V128:$Rn))),
-          (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv16i8v V128:$Rn), hsub))>;
+  def : Pat<(v4i32 (addlv (v16i8 V128:$Rn))),
+            (v4i32 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v16i8v") V128:$Rn), hsub))>;
 
-def : Pat<(v4i32 (AArch64uaddlv (v8i16 V128:$Rn))),
-          (v4i32 (SUBREG_TO_REG (i64 0), (UADDLVv8i16v V128:$Rn), ssub))>;
+  def : Pat<(v4i32 (addlv (v8i16 V128:$Rn))),
+            (v4i32 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v8i16v") V128:$Rn), ssub))>;
+
+  def : Pat<(v2i64 (addlv (v4i32 V128:$Rn))),
+            (v2i64 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v4i32v") V128:$Rn), dsub))>;
+}
+
+defm : SIMDAcrossLaneLongReductionIntrinsic<"UADDLV", AArch64uaddlv>;
+defm : SIMDAcrossLaneLongReductionIntrinsic<"SADDLV", AArch64saddlv>;
 
 // Patterns for across-vector intrinsics, that have a node equivalent, that
 // returns a vector (with only the low lane defined) instead of a scalar.

llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp

Lines changed: 152 additions & 0 deletions
@@ -411,6 +411,158 @@ void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
+// Ensure that the type coming from the extend instruction is the right size
+bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  // Check if the last instruction is an extend
+  MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  auto ExtOpc = ExtMI->getOpcode();
+
+  if (ExtOpc == TargetOpcode::G_ZEXT)
+    std::get<1>(MatchInfo) = 0;
+  else if (ExtOpc == TargetOpcode::G_SEXT)
+    std::get<1>(MatchInfo) = 1;
+  else
+    return false;
+
+  // Check if the source register is a valid type
+  Register ExtSrcReg = ExtMI->getOperand(1).getReg();
+  LLT ExtSrcTy = MRI.getType(ExtSrcReg);
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  if ((DstTy.getScalarSizeInBits() == 16 &&
+       ExtSrcTy.getNumElements() % 8 == 0) ||
+      (DstTy.getScalarSizeInBits() == 32 &&
+       ExtSrcTy.getNumElements() % 4 == 0) ||
+      (DstTy.getScalarSizeInBits() == 64 &&
+       ExtSrcTy.getNumElements() % 4 == 0)) {
+    std::get<0>(MatchInfo) = ExtSrcReg;
+    return true;
+  }
+  return false;
+}
+
+void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           MachineIRBuilder &B, GISelChangeObserver &Observer,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
+  Register SrcReg = std::get<0>(MatchInfo);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  LLT DstTy = MRI.getType(DstReg);
+
+  // If SrcTy has more elements than expected, split them into multiple
+  // instructions and sum the results
+  LLT MainTy;
+  SmallVector<Register, 1> WorkingRegisters;
+  unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
+  unsigned SrcNumElem = SrcTy.getNumElements();
+  if ((SrcScalSize == 8 && SrcNumElem > 16) ||
+      (SrcScalSize == 16 && SrcNumElem > 8) ||
+      (SrcScalSize == 32 && SrcNumElem > 4)) {
+
+    LLT LeftoverTy;
+    SmallVector<Register, 4> LeftoverRegs;
+    if (SrcScalSize == 8)
+      MainTy = LLT::fixed_vector(16, 8);
+    else if (SrcScalSize == 16)
+      MainTy = LLT::fixed_vector(8, 16);
+    else if (SrcScalSize == 32)
+      MainTy = LLT::fixed_vector(4, 32);
+    else
+      llvm_unreachable("Source's Scalar Size not supported");
+
+    // Extract the parts, put each extracted source through U/SADDLV, and put
+    // the values inside a small vec
+    extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
+                 LeftoverRegs, B, MRI);
+    for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
+      WorkingRegisters.push_back(LeftoverRegs[I]);
+    }
+  } else {
+    WorkingRegisters.push_back(SrcReg);
+    MainTy = SrcTy;
+  }
+
+  unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
+  LLT MidScalarLLT = LLT::scalar(MidScalarSize);
+  Register zeroReg =
+      B.buildConstant(LLT::scalar(64), 0)->getOperand(0).getReg();
+  for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
+    // If the number of elements is too small to build an instruction, extend
+    // its size before applying addlv
+    LLT WorkingRegTy = MRI.getType(WorkingRegisters[I]);
+    if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
+        (WorkingRegTy.getNumElements() == 4)) {
+      WorkingRegisters[I] =
+          B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                              : TargetOpcode::G_ZEXT,
+                       {LLT::fixed_vector(4, 16)}, {WorkingRegisters[I]})
+              ->getOperand(0)
+              .getReg();
+    }
+
+    // Generate the {U/S}ADDLV instruction, whose output is always double the
+    // Src's scalar size
+    LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
+                                      : LLT::fixed_vector(2, 64);
+    Register addlvReg = B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]})
+                            ->getOperand(0)
+                            .getReg();
+
+    // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
+    // v2i64 register.
+    //   i16, i32 results use v4i32 registers
+    //   i64 results use v2i64 registers
+    // Therefore we have to extract/truncate the value to the right type
+    if (MidScalarSize == 32 || MidScalarSize == 64) {
+      WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {MidScalarLLT}, {addlvReg, zeroReg})
+                                ->getOperand(0)
+                                .getReg();
+    } else {
+      Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {LLT::scalar(32)}, {addlvReg, zeroReg})
+                                ->getOperand(0)
+                                .getReg();
+      WorkingRegisters[I] =
+          B.buildTrunc({MidScalarLLT}, {extractReg})->getOperand(0).getReg();
+    }
+  }
+
+  Register outReg;
+  if (WorkingRegisters.size() > 1) {
+    outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
+                 ->getOperand(0)
+                 .getReg();
+    for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
+      outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I])
+                   ->getOperand(0)
+                   .getReg();
+    }
+  } else {
+    outReg = WorkingRegisters[0];
+  }
+
+  if (DstTy.getScalarSizeInBits() > MidScalarSize) {
+    // Handle the scalar value if the DstTy's scalar size is more than double
+    // the Src's scalar type
+    B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                        : TargetOpcode::G_ZEXT,
+                 {DstReg}, {outReg});
+  } else {
+    B.buildCopy(DstReg, outReg);
+  }
+
+  MI.eraseFromParent();
+}
+
 bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                         CombinerHelper &Helper, GISelChangeObserver &Observer) {
   // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if
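As a rough sketch of the splitting path in applyExtUaddvToUaddlv above (illustrative only: register names are invented, the even split is assumed to come out of extractParts as a G_UNMERGE_VALUES, and the exact instruction order may differ), a <32 x s8> source that is zero-extended and reduced to s32 does not fit one 128-bit register, so it is split in two, each half is reduced with G_UADDLV, and the partial sums are added and then extended to the destination width:

    %lo:_(<16 x s8>), %hi:_(<16 x s8>) = G_UNMERGE_VALUES %src(<32 x s8>)
    %zero:_(s64) = G_CONSTANT i64 0
    %lv0:_(<4 x s32>) = G_UADDLV %lo(<16 x s8>)
    %e0:_(s32) = G_EXTRACT_VECTOR_ELT %lv0(<4 x s32>), %zero(s64)
    %t0:_(s16) = G_TRUNC %e0(s32)
    %lv1:_(<4 x s32>) = G_UADDLV %hi(<16 x s8>)
    %e1:_(s32) = G_EXTRACT_VECTOR_ELT %lv1(<4 x s32>), %zero(s64)
    %t1:_(s16) = G_TRUNC %e1(s32)
    %partial:_(s16) = G_ADD %t0, %t1
    %sum:_(s32) = G_ZEXT %partial(s16)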
