Skip to content

Commit 1af7725

Browse files
committed
[RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X).
This is an alternative to the new pass proposed in llvm#87544. This improves a pattern that occurs in 531.deepsjeng_r, reducing the dynamic instruction count by 0.5%. This may be possible to improve in SelectionDAG, but given the special cases around shXadd formation, it's not obvious it can be done in a robust way without adding multiple special cases. I've used a GEP with 2 indices because that most closely resembles the motivating case. Most of the test cases are the simplest GEP case. One test has a logical right shift on an index, which is closer to the deepsjeng code. This requires special handling in isel to reverse a DAGCombiner canonicalization that turns a pair of shifts into (srl (and X, C1), C2). See also llvm#85734, which had a hacky version of a similar optimization.
1 parent 5748ad8 commit 1af7725

File tree

3 files changed

+177
-24
lines changed

3 files changed

+177
-24
lines changed

llvm/include/llvm/CodeGen/MachineCombinerPattern.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ enum class MachineCombinerPattern {
175175
FMADD_XA,
176176
FMSUB,
177177
FNMSUB,
178+
SHXADD_ADD_SLLI_OP1,
179+
SHXADD_ADD_SLLI_OP2,
178180

179181
// X86 VNNI
180182
DPWSSD,

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,13 +1829,94 @@ static bool getFPPatterns(MachineInstr &Root,
18291829
return getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
18301830
}
18311831

1832+
/// Utility routine that checks if \param MO is defined by an
1833+
/// \param CombineOpc instruction in the basic block \param MBB
1834+
static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
1835+
const MachineOperand &MO,
1836+
unsigned CombineOpc) {
1837+
const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
1838+
const MachineInstr *MI = nullptr;
1839+
1840+
if (MO.isReg() && MO.getReg().isVirtual())
1841+
MI = MRI.getUniqueVRegDef(MO.getReg());
1842+
// And it needs to be in the trace (otherwise, it won't have a depth).
1843+
if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
1844+
return nullptr;
1845+
// Must only used by the user we combine with.
1846+
if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
1847+
return nullptr;
1848+
1849+
return MI;
1850+
}
1851+
1852+
/// Utility routine that checks if \param MO is defined by a SLLI in \param
1853+
/// MBB that can be combined by splitting across 2 SHXADD instructions. The
1854+
/// first SHXADD shift amount is given by \param OuterShiftAmt.
1855+
static bool canCombineShiftIntoShXAdd(const MachineBasicBlock &MBB,
1856+
const MachineOperand &MO,
1857+
unsigned OuterShiftAmt) {
1858+
const MachineInstr *ShiftMI = canCombine(MBB, MO, RISCV::SLLI);
1859+
if (!ShiftMI)
1860+
return false;
1861+
1862+
unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
1863+
if (InnerShiftAmt < OuterShiftAmt || (InnerShiftAmt - OuterShiftAmt) > 3)
1864+
return false;
1865+
1866+
return true;
1867+
}
1868+
1869+
// Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
1870+
// (sh3add (sh2add Y, Z), X).
1871+
static bool
1872+
getSHXADDPatterns(const MachineInstr &Root,
1873+
SmallVectorImpl<MachineCombinerPattern> &Patterns) {
1874+
unsigned Opc = Root.getOpcode();
1875+
1876+
unsigned ShiftAmt;
1877+
switch (Opc) {
1878+
default:
1879+
return false;
1880+
case RISCV::SH1ADD:
1881+
ShiftAmt = 1;
1882+
break;
1883+
case RISCV::SH2ADD:
1884+
ShiftAmt = 2;
1885+
break;
1886+
case RISCV::SH3ADD:
1887+
ShiftAmt = 3;
1888+
break;
1889+
}
1890+
1891+
const MachineBasicBlock &MBB = *Root.getParent();
1892+
1893+
const MachineInstr *AddMI = canCombine(MBB, Root.getOperand(2), RISCV::ADD);
1894+
if (!AddMI)
1895+
return false;
1896+
1897+
bool Found = false;
1898+
if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(1), ShiftAmt)) {
1899+
Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP1);
1900+
Found = true;
1901+
}
1902+
if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(2), ShiftAmt)) {
1903+
Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP2);
1904+
Found = true;
1905+
}
1906+
1907+
return Found;
1908+
}
1909+
18321910
// Entry point called by the MachineCombiner pass: collect all alternative
// code-sequence patterns rooted at \p Root. Pattern families are tried in
// priority order — each early return stops the search, so FP fused-multiply
// patterns take precedence over the SHXADD patterns, which in turn take
// precedence over the target-independent patterns.
bool RISCVInstrInfo::getMachineCombinerPatterns(
    MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
    bool DoRegPressureReduce) const {

  if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
    return true;

  // SHXADD folding (see getSHXADDPatterns above) ignores register pressure;
  // it never increases the number of live values.
  if (getSHXADDPatterns(Root, Patterns))
    return true;

  return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                     DoRegPressureReduce);
}
@@ -1918,6 +1999,78 @@ static void combineFPFusedMultiply(MachineInstr &Root, MachineInstr &Prev,
19181999
DelInstrs.push_back(&Root);
19192000
}
19202001

2002+
// Combine (sh3add Z, (add X, (slli Y, 5))) to (sh3add (sh2add Y, Z), X).
2003+
static void
2004+
genShXAddAddShift(MachineInstr &Root, unsigned AddOpIdx,
2005+
SmallVectorImpl<MachineInstr *> &InsInstrs,
2006+
SmallVectorImpl<MachineInstr *> &DelInstrs,
2007+
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) {
2008+
MachineFunction *MF = Root.getMF();
2009+
MachineRegisterInfo &MRI = MF->getRegInfo();
2010+
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
2011+
2012+
unsigned OuterShiftAmt;
2013+
switch (Root.getOpcode()) {
2014+
default:
2015+
llvm_unreachable("Unexpected opcode");
2016+
case RISCV::SH1ADD:
2017+
OuterShiftAmt = 1;
2018+
break;
2019+
case RISCV::SH2ADD:
2020+
OuterShiftAmt = 2;
2021+
break;
2022+
case RISCV::SH3ADD:
2023+
OuterShiftAmt = 3;
2024+
break;
2025+
}
2026+
2027+
MachineInstr *AddMI = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
2028+
MachineInstr *ShiftMI =
2029+
MRI.getUniqueVRegDef(AddMI->getOperand(AddOpIdx).getReg());
2030+
2031+
unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
2032+
assert(InnerShiftAmt > OuterShiftAmt && "Unexpected shift amount");
2033+
2034+
unsigned InnerOpc;
2035+
switch (InnerShiftAmt - OuterShiftAmt) {
2036+
default:
2037+
llvm_unreachable("Unexpected shift amount");
2038+
case 0:
2039+
InnerOpc = RISCV::ADD;
2040+
break;
2041+
case 1:
2042+
InnerOpc = RISCV::SH1ADD;
2043+
break;
2044+
case 2:
2045+
InnerOpc = RISCV::SH2ADD;
2046+
break;
2047+
case 3:
2048+
InnerOpc = RISCV::SH3ADD;
2049+
break;
2050+
}
2051+
2052+
Register X = AddMI->getOperand(3 - AddOpIdx).getReg();
2053+
Register Y = ShiftMI->getOperand(1).getReg();
2054+
Register Z = Root.getOperand(1).getReg();
2055+
2056+
Register NewVR = MRI.createVirtualRegister(&RISCV::GPRRegClass);
2057+
2058+
auto MIB1 = BuildMI(*MF, MIMetadata(Root), TII->get(InnerOpc), NewVR)
2059+
.addReg(Y)
2060+
.addReg(Z);
2061+
auto MIB2 = BuildMI(*MF, MIMetadata(Root), TII->get(Root.getOpcode()),
2062+
Root.getOperand(0).getReg())
2063+
.addReg(NewVR)
2064+
.addReg(X);
2065+
2066+
InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
2067+
InsInstrs.push_back(MIB1);
2068+
InsInstrs.push_back(MIB2);
2069+
DelInstrs.push_back(ShiftMI);
2070+
DelInstrs.push_back(AddMI);
2071+
DelInstrs.push_back(&Root);
2072+
}
2073+
19212074
void RISCVInstrInfo::genAlternativeCodeSequence(
19222075
MachineInstr &Root, MachineCombinerPattern Pattern,
19232076
SmallVectorImpl<MachineInstr *> &InsInstrs,
@@ -1941,6 +2094,12 @@ void RISCVInstrInfo::genAlternativeCodeSequence(
19412094
combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
19422095
return;
19432096
}
2097+
case MachineCombinerPattern::SHXADD_ADD_SLLI_OP1:
2098+
genShXAddAddShift(Root, 1, InsInstrs, DelInstrs, InstrIdxForVirtReg);
2099+
return;
2100+
case MachineCombinerPattern::SHXADD_ADD_SLLI_OP2:
2101+
genShXAddAddShift(Root, 2, InsInstrs, DelInstrs, InstrIdxForVirtReg);
2102+
return;
19442103
}
19452104
}
19462105

llvm/test/CodeGen/RISCV/rv64zba.ll

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,9 +1404,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
14041404
;
14051405
; RV64ZBA-LABEL: sh6_sh3_add2:
14061406
; RV64ZBA: # %bb.0: # %entry
1407-
; RV64ZBA-NEXT: slli a1, a1, 6
1408-
; RV64ZBA-NEXT: add a0, a1, a0
1409-
; RV64ZBA-NEXT: sh3add a0, a2, a0
1407+
; RV64ZBA-NEXT: sh3add a1, a1, a2
1408+
; RV64ZBA-NEXT: sh3add a0, a1, a0
14101409
; RV64ZBA-NEXT: ret
14111410
entry:
14121411
%shl = shl i64 %z, 3
@@ -2111,9 +2110,8 @@ define i64 @array_index_sh1_sh3(ptr %p, i64 %idx1, i64 %idx2) {
21112110
;
21122111
; RV64ZBA-LABEL: array_index_sh1_sh3:
21132112
; RV64ZBA: # %bb.0:
2114-
; RV64ZBA-NEXT: slli a1, a1, 4
2115-
; RV64ZBA-NEXT: add a0, a0, a1
2116-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2113+
; RV64ZBA-NEXT: sh1add a1, a1, a2
2114+
; RV64ZBA-NEXT: sh3add a0, a1, a0
21172115
; RV64ZBA-NEXT: ld a0, 0(a0)
21182116
; RV64ZBA-NEXT: ret
21192117
%a = getelementptr inbounds [2 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2174,9 +2172,8 @@ define i32 @array_index_sh2_sh2(ptr %p, i64 %idx1, i64 %idx2) {
21742172
;
21752173
; RV64ZBA-LABEL: array_index_sh2_sh2:
21762174
; RV64ZBA: # %bb.0:
2177-
; RV64ZBA-NEXT: slli a1, a1, 4
2178-
; RV64ZBA-NEXT: add a0, a0, a1
2179-
; RV64ZBA-NEXT: sh2add a0, a2, a0
2175+
; RV64ZBA-NEXT: sh2add a1, a1, a2
2176+
; RV64ZBA-NEXT: sh2add a0, a1, a0
21802177
; RV64ZBA-NEXT: lw a0, 0(a0)
21812178
; RV64ZBA-NEXT: ret
21822179
%a = getelementptr inbounds [4 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2196,9 +2193,8 @@ define i64 @array_index_sh2_sh3(ptr %p, i64 %idx1, i64 %idx2) {
21962193
;
21972194
; RV64ZBA-LABEL: array_index_sh2_sh3:
21982195
; RV64ZBA: # %bb.0:
2199-
; RV64ZBA-NEXT: slli a1, a1, 5
2200-
; RV64ZBA-NEXT: add a0, a0, a1
2201-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2196+
; RV64ZBA-NEXT: sh2add a1, a1, a2
2197+
; RV64ZBA-NEXT: sh3add a0, a1, a0
22022198
; RV64ZBA-NEXT: ld a0, 0(a0)
22032199
; RV64ZBA-NEXT: ret
22042200
%a = getelementptr inbounds [4 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2238,9 +2234,8 @@ define i16 @array_index_sh3_sh1(ptr %p, i64 %idx1, i64 %idx2) {
22382234
;
22392235
; RV64ZBA-LABEL: array_index_sh3_sh1:
22402236
; RV64ZBA: # %bb.0:
2241-
; RV64ZBA-NEXT: slli a1, a1, 4
2242-
; RV64ZBA-NEXT: add a0, a0, a1
2243-
; RV64ZBA-NEXT: sh1add a0, a2, a0
2237+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2238+
; RV64ZBA-NEXT: sh1add a0, a1, a0
22442239
; RV64ZBA-NEXT: lh a0, 0(a0)
22452240
; RV64ZBA-NEXT: ret
22462241
%a = getelementptr inbounds [8 x i16], ptr %p, i64 %idx1, i64 %idx2
@@ -2260,9 +2255,8 @@ define i32 @array_index_sh3_sh2(ptr %p, i64 %idx1, i64 %idx2) {
22602255
;
22612256
; RV64ZBA-LABEL: array_index_sh3_sh2:
22622257
; RV64ZBA: # %bb.0:
2263-
; RV64ZBA-NEXT: slli a1, a1, 5
2264-
; RV64ZBA-NEXT: add a0, a0, a1
2265-
; RV64ZBA-NEXT: sh2add a0, a2, a0
2258+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2259+
; RV64ZBA-NEXT: sh2add a0, a1, a0
22662260
; RV64ZBA-NEXT: lw a0, 0(a0)
22672261
; RV64ZBA-NEXT: ret
22682262
%a = getelementptr inbounds [8 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2282,9 +2276,8 @@ define i64 @array_index_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
22822276
;
22832277
; RV64ZBA-LABEL: array_index_sh3_sh3:
22842278
; RV64ZBA: # %bb.0:
2285-
; RV64ZBA-NEXT: slli a1, a1, 6
2286-
; RV64ZBA-NEXT: add a0, a0, a1
2287-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2279+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2280+
; RV64ZBA-NEXT: sh3add a0, a1, a0
22882281
; RV64ZBA-NEXT: ld a0, 0(a0)
22892282
; RV64ZBA-NEXT: ret
22902283
%a = getelementptr inbounds [8 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2308,9 +2301,8 @@ define i64 @array_index_lshr_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
23082301
; RV64ZBA-LABEL: array_index_lshr_sh3_sh3:
23092302
; RV64ZBA: # %bb.0:
23102303
; RV64ZBA-NEXT: srli a1, a1, 58
2311-
; RV64ZBA-NEXT: slli a1, a1, 6
2312-
; RV64ZBA-NEXT: add a0, a0, a1
2313-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2304+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2305+
; RV64ZBA-NEXT: sh3add a0, a1, a0
23142306
; RV64ZBA-NEXT: ld a0, 0(a0)
23152307
; RV64ZBA-NEXT: ret
23162308
%shr = lshr i64 %idx1, 58

0 commit comments

Comments
 (0)