Skip to content

Commit 7f1b9ad

Browse files
authored
[RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X). (llvm#87884)
This improves a pattern that occurs in 531.deepsjeng_r. Reducing the dynamic instruction count by 0.5%. This may be possible to improve in SelectionDAG, but given the special cases around shXadd formation, it's not obvious it can be done in a robust way without adding multiple special cases. I've used a GEP with 2 indices because that mostly closely resembles the motivating case. Most of the test cases are the simplest GEP case. One test has a logical right shift on an index which is closer to the deepsjeng code. This requires special handling in isel to reverse a DAGCombiner canonicalization that turns a pair of shifts into (srl (and X, C1), C2).
1 parent 6ca5a41 commit 7f1b9ad

File tree

3 files changed

+169
-24
lines changed

3 files changed

+169
-24
lines changed

llvm/include/llvm/CodeGen/MachineCombinerPattern.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ enum class MachineCombinerPattern {
175175
FMADD_XA,
176176
FMSUB,
177177
FNMSUB,
178+
SHXADD_ADD_SLLI_OP1,
179+
SHXADD_ADD_SLLI_OP2,
178180

179181
// X86 VNNI
180182
DPWSSD,

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,13 +1775,96 @@ static bool getFPPatterns(MachineInstr &Root,
17751775
return getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
17761776
}
17771777

1778+
/// Utility routine that checks if \param MO is defined by an
1779+
/// \param CombineOpc instruction in the basic block \param MBB
1780+
static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
1781+
const MachineOperand &MO,
1782+
unsigned CombineOpc) {
1783+
const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
1784+
const MachineInstr *MI = nullptr;
1785+
1786+
if (MO.isReg() && MO.getReg().isVirtual())
1787+
MI = MRI.getUniqueVRegDef(MO.getReg());
1788+
// And it needs to be in the trace (otherwise, it won't have a depth).
1789+
if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
1790+
return nullptr;
1791+
// Must only used by the user we combine with.
1792+
if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
1793+
return nullptr;
1794+
1795+
return MI;
1796+
}
1797+
1798+
/// Utility routine that checks if \param MO is defined by a SLLI in \param
1799+
/// MBB that can be combined by splitting across 2 SHXADD instructions. The
1800+
/// first SHXADD shift amount is given by \param OuterShiftAmt.
1801+
static bool canCombineShiftIntoShXAdd(const MachineBasicBlock &MBB,
1802+
const MachineOperand &MO,
1803+
unsigned OuterShiftAmt) {
1804+
const MachineInstr *ShiftMI = canCombine(MBB, MO, RISCV::SLLI);
1805+
if (!ShiftMI)
1806+
return false;
1807+
1808+
unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
1809+
if (InnerShiftAmt < OuterShiftAmt || (InnerShiftAmt - OuterShiftAmt) > 3)
1810+
return false;
1811+
1812+
return true;
1813+
}
1814+
1815+
// Returns the shift amount from a SHXADD instruction. Returns 0 if the
1816+
// instruction is not a SHXADD.
1817+
static unsigned getSHXADDShiftAmount(unsigned Opc) {
1818+
switch (Opc) {
1819+
default:
1820+
return 0;
1821+
case RISCV::SH1ADD:
1822+
return 1;
1823+
case RISCV::SH2ADD:
1824+
return 2;
1825+
case RISCV::SH3ADD:
1826+
return 3;
1827+
}
1828+
}
1829+
1830+
// Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
1831+
// (sh3add (sh2add Y, Z), X).
1832+
static bool
1833+
getSHXADDPatterns(const MachineInstr &Root,
1834+
SmallVectorImpl<MachineCombinerPattern> &Patterns) {
1835+
unsigned ShiftAmt = getSHXADDShiftAmount(Root.getOpcode());
1836+
if (!ShiftAmt)
1837+
return false;
1838+
1839+
const MachineBasicBlock &MBB = *Root.getParent();
1840+
1841+
const MachineInstr *AddMI = canCombine(MBB, Root.getOperand(2), RISCV::ADD);
1842+
if (!AddMI)
1843+
return false;
1844+
1845+
bool Found = false;
1846+
if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(1), ShiftAmt)) {
1847+
Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP1);
1848+
Found = true;
1849+
}
1850+
if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(2), ShiftAmt)) {
1851+
Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP2);
1852+
Found = true;
1853+
}
1854+
1855+
return Found;
1856+
}
1857+
17781858
bool RISCVInstrInfo::getMachineCombinerPatterns(
17791859
MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
17801860
bool DoRegPressureReduce) const {
17811861

17821862
if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
17831863
return true;
17841864

1865+
if (getSHXADDPatterns(Root, Patterns))
1866+
return true;
1867+
17851868
return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
17861869
DoRegPressureReduce);
17871870
}
@@ -1864,6 +1947,68 @@ static void combineFPFusedMultiply(MachineInstr &Root, MachineInstr &Prev,
18641947
DelInstrs.push_back(&Root);
18651948
}
18661949

1950+
// Combine patterns like (sh3add Z, (add X, (slli Y, 5))) to
1951+
// (sh3add (sh2add Y, Z), X) if the shift amount can be split across two
1952+
// shXadd instructions. The outer shXadd keeps its original opcode.
1953+
static void
1954+
genShXAddAddShift(MachineInstr &Root, unsigned AddOpIdx,
1955+
SmallVectorImpl<MachineInstr *> &InsInstrs,
1956+
SmallVectorImpl<MachineInstr *> &DelInstrs,
1957+
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) {
1958+
MachineFunction *MF = Root.getMF();
1959+
MachineRegisterInfo &MRI = MF->getRegInfo();
1960+
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
1961+
1962+
unsigned OuterShiftAmt = getSHXADDShiftAmount(Root.getOpcode());
1963+
assert(OuterShiftAmt != 0 && "Unexpected opcode");
1964+
1965+
MachineInstr *AddMI = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
1966+
MachineInstr *ShiftMI =
1967+
MRI.getUniqueVRegDef(AddMI->getOperand(AddOpIdx).getReg());
1968+
1969+
unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
1970+
assert(InnerShiftAmt > OuterShiftAmt && "Unexpected shift amount");
1971+
1972+
unsigned InnerOpc;
1973+
switch (InnerShiftAmt - OuterShiftAmt) {
1974+
default:
1975+
llvm_unreachable("Unexpected shift amount");
1976+
case 0:
1977+
InnerOpc = RISCV::ADD;
1978+
break;
1979+
case 1:
1980+
InnerOpc = RISCV::SH1ADD;
1981+
break;
1982+
case 2:
1983+
InnerOpc = RISCV::SH2ADD;
1984+
break;
1985+
case 3:
1986+
InnerOpc = RISCV::SH3ADD;
1987+
break;
1988+
}
1989+
1990+
const MachineOperand &X = AddMI->getOperand(3 - AddOpIdx);
1991+
const MachineOperand &Y = ShiftMI->getOperand(1);
1992+
const MachineOperand &Z = Root.getOperand(1);
1993+
1994+
Register NewVR = MRI.createVirtualRegister(&RISCV::GPRRegClass);
1995+
1996+
auto MIB1 = BuildMI(*MF, MIMetadata(Root), TII->get(InnerOpc), NewVR)
1997+
.addReg(Y.getReg(), getKillRegState(Y.isKill()))
1998+
.addReg(Z.getReg(), getKillRegState(Z.isKill()));
1999+
auto MIB2 = BuildMI(*MF, MIMetadata(Root), TII->get(Root.getOpcode()),
2000+
Root.getOperand(0).getReg())
2001+
.addReg(NewVR, RegState::Kill)
2002+
.addReg(X.getReg(), getKillRegState(X.isKill()));
2003+
2004+
InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
2005+
InsInstrs.push_back(MIB1);
2006+
InsInstrs.push_back(MIB2);
2007+
DelInstrs.push_back(ShiftMI);
2008+
DelInstrs.push_back(AddMI);
2009+
DelInstrs.push_back(&Root);
2010+
}
2011+
18672012
void RISCVInstrInfo::genAlternativeCodeSequence(
18682013
MachineInstr &Root, MachineCombinerPattern Pattern,
18692014
SmallVectorImpl<MachineInstr *> &InsInstrs,
@@ -1887,6 +2032,12 @@ void RISCVInstrInfo::genAlternativeCodeSequence(
18872032
combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
18882033
return;
18892034
}
2035+
case MachineCombinerPattern::SHXADD_ADD_SLLI_OP1:
2036+
genShXAddAddShift(Root, 1, InsInstrs, DelInstrs, InstrIdxForVirtReg);
2037+
return;
2038+
case MachineCombinerPattern::SHXADD_ADD_SLLI_OP2:
2039+
genShXAddAddShift(Root, 2, InsInstrs, DelInstrs, InstrIdxForVirtReg);
2040+
return;
18902041
}
18912042
}
18922043

llvm/test/CodeGen/RISCV/rv64zba.ll

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,9 +1404,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
14041404
;
14051405
; RV64ZBA-LABEL: sh6_sh3_add2:
14061406
; RV64ZBA: # %bb.0: # %entry
1407-
; RV64ZBA-NEXT: slli a1, a1, 6
1408-
; RV64ZBA-NEXT: add a0, a1, a0
1409-
; RV64ZBA-NEXT: sh3add a0, a2, a0
1407+
; RV64ZBA-NEXT: sh3add a1, a1, a2
1408+
; RV64ZBA-NEXT: sh3add a0, a1, a0
14101409
; RV64ZBA-NEXT: ret
14111410
entry:
14121411
%shl = shl i64 %z, 3
@@ -2111,9 +2110,8 @@ define i64 @array_index_sh1_sh3(ptr %p, i64 %idx1, i64 %idx2) {
21112110
;
21122111
; RV64ZBA-LABEL: array_index_sh1_sh3:
21132112
; RV64ZBA: # %bb.0:
2114-
; RV64ZBA-NEXT: slli a1, a1, 4
2115-
; RV64ZBA-NEXT: add a0, a0, a1
2116-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2113+
; RV64ZBA-NEXT: sh1add a1, a1, a2
2114+
; RV64ZBA-NEXT: sh3add a0, a1, a0
21172115
; RV64ZBA-NEXT: ld a0, 0(a0)
21182116
; RV64ZBA-NEXT: ret
21192117
%a = getelementptr inbounds [2 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2174,9 +2172,8 @@ define i32 @array_index_sh2_sh2(ptr %p, i64 %idx1, i64 %idx2) {
21742172
;
21752173
; RV64ZBA-LABEL: array_index_sh2_sh2:
21762174
; RV64ZBA: # %bb.0:
2177-
; RV64ZBA-NEXT: slli a1, a1, 4
2178-
; RV64ZBA-NEXT: add a0, a0, a1
2179-
; RV64ZBA-NEXT: sh2add a0, a2, a0
2175+
; RV64ZBA-NEXT: sh2add a1, a1, a2
2176+
; RV64ZBA-NEXT: sh2add a0, a1, a0
21802177
; RV64ZBA-NEXT: lw a0, 0(a0)
21812178
; RV64ZBA-NEXT: ret
21822179
%a = getelementptr inbounds [4 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2196,9 +2193,8 @@ define i64 @array_index_sh2_sh3(ptr %p, i64 %idx1, i64 %idx2) {
21962193
;
21972194
; RV64ZBA-LABEL: array_index_sh2_sh3:
21982195
; RV64ZBA: # %bb.0:
2199-
; RV64ZBA-NEXT: slli a1, a1, 5
2200-
; RV64ZBA-NEXT: add a0, a0, a1
2201-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2196+
; RV64ZBA-NEXT: sh2add a1, a1, a2
2197+
; RV64ZBA-NEXT: sh3add a0, a1, a0
22022198
; RV64ZBA-NEXT: ld a0, 0(a0)
22032199
; RV64ZBA-NEXT: ret
22042200
%a = getelementptr inbounds [4 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2238,9 +2234,8 @@ define i16 @array_index_sh3_sh1(ptr %p, i64 %idx1, i64 %idx2) {
22382234
;
22392235
; RV64ZBA-LABEL: array_index_sh3_sh1:
22402236
; RV64ZBA: # %bb.0:
2241-
; RV64ZBA-NEXT: slli a1, a1, 4
2242-
; RV64ZBA-NEXT: add a0, a0, a1
2243-
; RV64ZBA-NEXT: sh1add a0, a2, a0
2237+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2238+
; RV64ZBA-NEXT: sh1add a0, a1, a0
22442239
; RV64ZBA-NEXT: lh a0, 0(a0)
22452240
; RV64ZBA-NEXT: ret
22462241
%a = getelementptr inbounds [8 x i16], ptr %p, i64 %idx1, i64 %idx2
@@ -2260,9 +2255,8 @@ define i32 @array_index_sh3_sh2(ptr %p, i64 %idx1, i64 %idx2) {
22602255
;
22612256
; RV64ZBA-LABEL: array_index_sh3_sh2:
22622257
; RV64ZBA: # %bb.0:
2263-
; RV64ZBA-NEXT: slli a1, a1, 5
2264-
; RV64ZBA-NEXT: add a0, a0, a1
2265-
; RV64ZBA-NEXT: sh2add a0, a2, a0
2258+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2259+
; RV64ZBA-NEXT: sh2add a0, a1, a0
22662260
; RV64ZBA-NEXT: lw a0, 0(a0)
22672261
; RV64ZBA-NEXT: ret
22682262
%a = getelementptr inbounds [8 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2282,9 +2276,8 @@ define i64 @array_index_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
22822276
;
22832277
; RV64ZBA-LABEL: array_index_sh3_sh3:
22842278
; RV64ZBA: # %bb.0:
2285-
; RV64ZBA-NEXT: slli a1, a1, 6
2286-
; RV64ZBA-NEXT: add a0, a0, a1
2287-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2279+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2280+
; RV64ZBA-NEXT: sh3add a0, a1, a0
22882281
; RV64ZBA-NEXT: ld a0, 0(a0)
22892282
; RV64ZBA-NEXT: ret
22902283
%a = getelementptr inbounds [8 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2308,9 +2301,8 @@ define i64 @array_index_lshr_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
23082301
; RV64ZBA-LABEL: array_index_lshr_sh3_sh3:
23092302
; RV64ZBA: # %bb.0:
23102303
; RV64ZBA-NEXT: srli a1, a1, 58
2311-
; RV64ZBA-NEXT: slli a1, a1, 6
2312-
; RV64ZBA-NEXT: add a0, a0, a1
2313-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2304+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2305+
; RV64ZBA-NEXT: sh3add a0, a1, a0
23142306
; RV64ZBA-NEXT: ld a0, 0(a0)
23152307
; RV64ZBA-NEXT: ret
23162308
%shr = lshr i64 %idx1, 58

0 commit comments

Comments
 (0)