Commit b4017d8

[AArch64][GlobalISel] Improve MULL generation (#112405)
This splits the existing post-legalizer lowering of vector umull/smull into two parts: one that performs the mul(ext, ext) -> mull optimization, and one that performs the v2i64 mul scalarization. The mull part is moved to the post-legalizer combiner and has been taught a few extra tricks from SDAG, using known bits to convert mul(sext, zext) or mul(zext, zero-upper-bits) into umull. This can be important for preventing the v2i64 scalarization of muls.
Parent: 1ee8fe8
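
A minimal scalar sketch of why the unsigned rewrite is sound, in plain C++ (umull below is an illustrative stand-in for what the AArch64 UMULL instruction computes per vector lane, not an LLVM API):

#include <cassert>
#include <cstdint>

// Widening 32x32 -> 64 multiply, as UMULL performs per lane.
static uint64_t umull(uint32_t A, uint32_t B) {
  return static_cast<uint64_t>(A) * static_cast<uint64_t>(B);
}

int main() {
  // Operands whose upper 32 bits are known zero, e.g. produced by a zext
  // or masked by `and x, 0xffffffff` -- the case known bits can prove.
  uint64_t X = 0xDEADBEEFULL;
  uint64_t Y = 0x12345678ULL;
  // The full 64x64 multiply agrees with the widening multiply of the low
  // halves, so the mul can be rewritten to umull.
  assert(X * Y == umull(static_cast<uint32_t>(X), static_cast<uint32_t>(Y)));
  return 0;
}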

File tree: 6 files changed (+480, -946 lines)

llvm/lib/Target/AArch64/AArch64Combine.td

Lines changed: 15 additions & 7 deletions
@@ -217,11 +217,19 @@ def mul_const : GICombineRule<
   (apply [{ applyAArch64MulConstCombine(*${root}, MRI, B, ${matchinfo}); }])
 >;
 
-def lower_mull : GICombineRule<
-  (defs root:$root),
-  (match (wip_match_opcode G_MUL):$root,
-         [{ return matchExtMulToMULL(*${root}, MRI); }]),
-  (apply [{ applyExtMulToMULL(*${root}, MRI, B, Observer); }])
+def mull_matchdata : GIDefMatchData<"std::tuple<bool, Register, Register>">;
+def extmultomull : GICombineRule<
+  (defs root:$root, mull_matchdata:$matchinfo),
+  (match (G_MUL $dst, $src1, $src2):$root,
+         [{ return matchExtMulToMULL(*${root}, MRI, VT, ${matchinfo}); }]),
+  (apply [{ applyExtMulToMULL(*${root}, MRI, B, Observer, ${matchinfo}); }])
+>;
+
+def lower_mulv2s64 : GICombineRule<
+  (defs root:$root, mull_matchdata:$matchinfo),
+  (match (G_MUL $dst, $src1, $src2):$root,
+         [{ return matchMulv2s64(*${root}, MRI); }]),
+  (apply [{ applyMulv2s64(*${root}, MRI, B, Observer); }])
 >;
 
 def build_vector_to_dup : GICombineRule<
@@ -316,7 +324,7 @@ def AArch64PostLegalizerLowering
                        icmp_lowering, build_vector_lowering,
                        lower_vector_fcmp, form_truncstore,
                        vector_sext_inreg_to_shift,
-                       unmerge_ext_to_unmerge, lower_mull,
+                       unmerge_ext_to_unmerge, lower_mulv2s64,
                        vector_unmerge_lowering, insertelt_nonconst]> {
 }
 
@@ -339,5 +347,5 @@ def AArch64PostLegalizerCombiner
                        select_to_minmax, or_to_bsp, combine_concat_vector,
                        commute_constant_to_rhs,
                        push_freeze_to_prevent_poison_from_propagating,
-                       combine_mul_cmlt, combine_use_vector_truncate]> {
+                       combine_mul_cmlt, combine_use_vector_truncate, extmultomull]> {
 }

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp

Lines changed: 116 additions & 0 deletions
@@ -438,6 +438,122 @@ void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Match mul({z/s}ext , {z/s}ext) => {u/s}mull
+bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
+                       GISelValueTracking *KB,
+                       std::tuple<bool, Register, Register> &MatchInfo) {
+  // Get the instructions that defined the source operand
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
+  unsigned I1Opc = I1->getOpcode();
+  unsigned I2Opc = I2->getOpcode();
+  unsigned EltSize = DstTy.getScalarSizeInBits();
+
+  if (!DstTy.isVector() || I1->getNumOperands() < 2 || I2->getNumOperands() < 2)
+    return false;
+
+  auto IsAtLeastDoubleExtend = [&](Register R) {
+    LLT Ty = MRI.getType(R);
+    return EltSize >= Ty.getScalarSizeInBits() * 2;
+  };
+
+  // If the source operands were EXTENDED before, then {U/S}MULL can be used
+  bool IsZExt1 =
+      I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_ANYEXT;
+  bool IsZExt2 =
+      I2Opc == TargetOpcode::G_ZEXT || I2Opc == TargetOpcode::G_ANYEXT;
+  if (IsZExt1 && IsZExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
+      IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
+    get<0>(MatchInfo) = true;
+    get<1>(MatchInfo) = I1->getOperand(1).getReg();
+    get<2>(MatchInfo) = I2->getOperand(1).getReg();
+    return true;
+  }
+
+  bool IsSExt1 =
+      I1Opc == TargetOpcode::G_SEXT || I1Opc == TargetOpcode::G_ANYEXT;
+  bool IsSExt2 =
+      I2Opc == TargetOpcode::G_SEXT || I2Opc == TargetOpcode::G_ANYEXT;
+  if (IsSExt1 && IsSExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
+      IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
+    get<0>(MatchInfo) = false;
+    get<1>(MatchInfo) = I1->getOperand(1).getReg();
+    get<2>(MatchInfo) = I2->getOperand(1).getReg();
+    return true;
+  }
+
+  // Select UMULL if we can replace the other operand with an extend.
+  APInt Mask = APInt::getHighBitsSet(EltSize, EltSize / 2);
+  if (KB && (IsZExt1 || IsZExt2) &&
+      IsAtLeastDoubleExtend(IsZExt1 ? I1->getOperand(1).getReg()
+                                    : I2->getOperand(1).getReg())) {
+    Register ZExtOp =
+        IsZExt1 ? MI.getOperand(2).getReg() : MI.getOperand(1).getReg();
+    if (KB->maskedValueIsZero(ZExtOp, Mask)) {
+      get<0>(MatchInfo) = true;
+      get<1>(MatchInfo) = IsZExt1 ? I1->getOperand(1).getReg() : ZExtOp;
+      get<2>(MatchInfo) = IsZExt1 ? ZExtOp : I2->getOperand(1).getReg();
+      return true;
+    }
+  } else if (KB && DstTy == LLT::fixed_vector(2, 64) &&
+             KB->maskedValueIsZero(MI.getOperand(1).getReg(), Mask) &&
+             KB->maskedValueIsZero(MI.getOperand(2).getReg(), Mask)) {
+    get<0>(MatchInfo) = true;
+    get<1>(MatchInfo) = MI.getOperand(1).getReg();
+    get<2>(MatchInfo) = MI.getOperand(2).getReg();
+    return true;
+  }
+
+  if (KB && (IsSExt1 || IsSExt2) &&
+      IsAtLeastDoubleExtend(IsSExt1 ? I1->getOperand(1).getReg()
+                                    : I2->getOperand(1).getReg())) {
+    Register SExtOp =
+        IsSExt1 ? MI.getOperand(2).getReg() : MI.getOperand(1).getReg();
+    if (KB->computeNumSignBits(SExtOp) > EltSize / 2) {
+      get<0>(MatchInfo) = false;
+      get<1>(MatchInfo) = IsSExt1 ? I1->getOperand(1).getReg() : SExtOp;
+      get<2>(MatchInfo) = IsSExt1 ? SExtOp : I2->getOperand(1).getReg();
+      return true;
+    }
+  } else if (KB && DstTy == LLT::fixed_vector(2, 64) &&
+             KB->computeNumSignBits(MI.getOperand(1).getReg()) > EltSize / 2 &&
+             KB->computeNumSignBits(MI.getOperand(2).getReg()) > EltSize / 2) {
+    get<0>(MatchInfo) = false;
+    get<1>(MatchInfo) = MI.getOperand(1).getReg();
+    get<2>(MatchInfo) = MI.getOperand(2).getReg();
+    return true;
+  }
+
+  return false;
+}
+
+void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
+                       MachineIRBuilder &B, GISelChangeObserver &Observer,
+                       std::tuple<bool, Register, Register> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_MUL &&
+         "Expected a G_MUL instruction");
+
+  // Get the instructions that defined the source operand
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  bool IsZExt = get<0>(MatchInfo);
+  Register Src1Reg = get<1>(MatchInfo);
+  Register Src2Reg = get<2>(MatchInfo);
+  LLT Src1Ty = MRI.getType(Src1Reg);
+  LLT Src2Ty = MRI.getType(Src2Reg);
+  LLT HalfDstTy = DstTy.changeElementSize(DstTy.getScalarSizeInBits() / 2);
+  unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;
+
+  if (Src1Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
+    Src1Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src1Reg}).getReg(0);
+  if (Src2Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
+    Src2Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src2Reg}).getReg(0);
+
+  B.buildInstr(IsZExt ? AArch64::G_UMULL : AArch64::G_SMULL,
+               {MI.getOperand(0).getReg()}, {Src1Reg, Src2Reg});
+  MI.eraseFromParent();
+}
+
 class AArch64PostLegalizerCombinerImpl : public Combiner {
 protected:
   const CombinerHelper Helper;
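
A small standalone sketch of the two value-tracking queries the match relies on, written directly against llvm::APInt and llvm::KnownBits rather than the GISelValueTracking wrapper; the KnownBits value below is a hand-built stand-in for what the analysis would compute for a register:

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
#include <cassert>

using namespace llvm;

int main() {
  // For a v2s64 multiply, EltSize is 64, so the mask covers the top 32 bits,
  // matching APInt::getHighBitsSet(EltSize, EltSize / 2) in the diff above.
  unsigned EltSize = 64;
  APInt Mask = APInt::getHighBitsSet(EltSize, EltSize / 2);

  // Stand-in for an operand defined by `and x, 0xffffffff`: the analysis
  // would report the whole upper half as known zero.
  KnownBits Known(EltSize);
  Known.Zero |= Mask;

  // maskedValueIsZero(Reg, Mask) boils down to this subset test; success
  // means the operand behaves like a zero-extend, enabling G_UMULL.
  assert(Mask.isSubsetOf(Known.Zero));

  // The signed counterpart: more than EltSize/2 sign bits means the value
  // survives a signed truncate to the low half, enabling G_SMULL.
  APInt C(EltSize, -42, /*isSigned=*/true);
  assert(C.getNumSignBits() > EltSize / 2);
  return 0;
}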

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp

Lines changed: 9 additions & 53 deletions
@@ -1190,68 +1190,24 @@ void applyUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
 // Doing these two matches in one function to ensure that the order of matching
 // will always be the same.
 // Try lowering MUL to MULL before trying to scalarize if needed.
-bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI) {
+bool matchMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI) {
   // Get the instructions that defined the source operand
   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
-  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
-  MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
-
-  if (DstTy.isVector()) {
-    // If the source operands were EXTENDED before, then {U/S}MULL can be used
-    unsigned I1Opc = I1->getOpcode();
-    unsigned I2Opc = I2->getOpcode();
-    if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
-         (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
-        (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
-         MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
-        (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
-         MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
-      return true;
-    }
-    // If result type is v2s64, scalarise the instruction
-    else if (DstTy == LLT::fixed_vector(2, 64)) {
-      return true;
-    }
-  }
-  return false;
+  return DstTy == LLT::fixed_vector(2, 64);
 }
 
-void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
-                       MachineIRBuilder &B, GISelChangeObserver &Observer) {
+void applyMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI,
+                   MachineIRBuilder &B, GISelChangeObserver &Observer) {
   assert(MI.getOpcode() == TargetOpcode::G_MUL &&
          "Expected a G_MUL instruction");
 
   // Get the instructions that defined the source operand
   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
-  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
-  MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
-
-  // If the source operands were EXTENDED before, then {U/S}MULL can be used
-  unsigned I1Opc = I1->getOpcode();
-  unsigned I2Opc = I2->getOpcode();
-  if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
-       (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
-      (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
-       MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
-      (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
-       MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
-
-    B.setInstrAndDebugLoc(MI);
-    B.buildInstr(I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UMULL
-                                                         : AArch64::G_SMULL,
-                 {MI.getOperand(0).getReg()},
-                 {I1->getOperand(1).getReg(), I2->getOperand(1).getReg()});
-    MI.eraseFromParent();
-  }
-  // If result type is v2s64, scalarise the instruction
-  else if (DstTy == LLT::fixed_vector(2, 64)) {
-    LegalizerHelper Helper(*MI.getMF(), Observer, B);
-    B.setInstrAndDebugLoc(MI);
-    Helper.fewerElementsVector(
-        MI, 0,
-        DstTy.changeElementCount(
-            DstTy.getElementCount().divideCoefficientBy(2)));
-  }
+  assert(DstTy == LLT::fixed_vector(2, 64) && "Expected v2s64 Mul");
+  LegalizerHelper Helper(*MI.getMF(), Observer, B);
+  Helper.fewerElementsVector(
+      MI, 0,
+      DstTy.changeElementCount(DstTy.getElementCount().divideCoefficientBy(2)));
 }
 
 class AArch64PostLegalizerLoweringImpl : public Combiner {
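
For the scalarization path, a plain-C++ model of what the lowered v2s64 multiply computes once fewerElementsVector has split it: one 64-bit multiply per lane, since AArch64 has no 64x64-bit vector multiply (the function name mul_v2s64 is illustrative only):

#include <cassert>
#include <cstdint>

// One scalar 64-bit multiply per lane -- the effect of splitting the
// <2 x s64> G_MUL into two s64 G_MULs.
static void mul_v2s64(uint64_t Dst[2], const uint64_t A[2],
                      const uint64_t B[2]) {
  Dst[0] = A[0] * B[0];
  Dst[1] = A[1] * B[1];
}

int main() {
  uint64_t A[2] = {3, 1ULL << 40};
  uint64_t B[2] = {7, 5};
  uint64_t D[2];
  mul_v2s64(D, A, B);
  assert(D[0] == 21 && D[1] == (5ULL << 40));
  return 0;
}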
