Skip to content

Commit 5d24969

Browse files
committed
[RISCV] Add a MIR pass to reassociate shXadd, add, and slli to form more shXadd.
This reassociates patterns like (sh3add Z, (add X, (slli Y, 6))) into (sh3add (sh3add Y, Z), X). This improves a pattern that occurs in 531.deepsjeng_r. Reducing the dynamic instruction count by 0.5%. This may be possible to improve in SelectionDAG, but given the special cases around shXadd formation, it's not obvious it can be done in a robust way without adding multiple special cases. I've used a GEP with 2 indices because that mostly closely resembles the motivating case. Most of the test cases are the simplest GEP case. One test has a logical right shift on an index which is closer to the deepsjeng code. This requires special handling in isel to reverse a DAGCombiner canonicalization that turns a pair of shifts into (srl (and X, C1), C2). See also #85734 which had a hacky version of a similar optimization.
1 parent 4abb722 commit 5d24969

File tree

6 files changed

+194
-24
lines changed

6 files changed

+194
-24
lines changed

llvm/lib/Target/RISCV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ add_llvm_target(RISCVCodeGen
4646
RISCVMachineFunctionInfo.cpp
4747
RISCVMergeBaseOffset.cpp
4848
RISCVOptWInstrs.cpp
49+
RISCVOptZba.cpp
4950
RISCVPostRAExpandPseudoInsts.cpp
5051
RISCVRedundantCopyElimination.cpp
5152
RISCVMoveMerger.cpp

llvm/lib/Target/RISCV/RISCV.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ void initializeRISCVFoldMasksPass(PassRegistry &);
4646
FunctionPass *createRISCVOptWInstrsPass();
4747
void initializeRISCVOptWInstrsPass(PassRegistry &);
4848

49+
FunctionPass *createRISCVOptZbaPass();
50+
void initializeRISCVOptZbaPass(PassRegistry &);
51+
4952
FunctionPass *createRISCVMergeBaseOffsetOptPass();
5053
void initializeRISCVMergeBaseOffsetOptPass(PassRegistry &);
5154

llvm/lib/Target/RISCV/RISCVOptZba.cpp

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
//===- RISCVOptZba.cpp - MI Zba instruction optimizations -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
//
9+
// This pass reassociates expressions like
10+
// (sh3add Z, (add X, (slli Y, 5)))
11+
// To
12+
// (sh3add (sh2add Y, Z), X)
13+
//
14+
// If the shift amount is small enough. The outer shXadd keeps its original
15+
// opcode. The inner shXadd shift amount is the difference between the slli
16+
// shift amount and the outer shXadd shift amount.
17+
//
18+
// This pattern can appear when indexing a two dimensional array, but it is not
19+
// limited to that.
20+
//
21+
// TODO: We can also support slli.uw by using shXadd.uw for the inner shXadd.
22+
// TODO: This can be generalized to deeper expressions.
23+
//
24+
//===---------------------------------------------------------------------===//
25+
26+
#include "RISCV.h"
27+
#include "RISCVSubtarget.h"
28+
#include "llvm/CodeGen/MachineFunctionPass.h"
29+
30+
using namespace llvm;
31+
32+
#define DEBUG_TYPE "riscv-opt-zba"
33+
#define RISCV_OPT_ZBA_NAME "RISC-V Optimize Zba"
34+
35+
namespace {
36+
37+
class RISCVOptZba : public MachineFunctionPass {
38+
public:
39+
static char ID;
40+
41+
RISCVOptZba() : MachineFunctionPass(ID) {}
42+
43+
bool runOnMachineFunction(MachineFunction &MF) override;
44+
45+
void getAnalysisUsage(AnalysisUsage &AU) const override {
46+
AU.setPreservesCFG();
47+
MachineFunctionPass::getAnalysisUsage(AU);
48+
}
49+
50+
StringRef getPassName() const override { return RISCV_OPT_ZBA_NAME; }
51+
};
52+
53+
} // end anonymous namespace
54+
55+
char RISCVOptZba::ID = 0;
56+
INITIALIZE_PASS(RISCVOptZba, DEBUG_TYPE, RISCV_OPT_ZBA_NAME, false, false)
57+
58+
FunctionPass *llvm::createRISCVOptZbaPass() { return new RISCVOptZba(); }
59+
60+
static MachineInstr *findShift(Register Reg, const MachineBasicBlock &MBB,
61+
MachineRegisterInfo &MRI) {
62+
if (!Reg.isVirtual())
63+
return nullptr;
64+
65+
MachineInstr *Shift = MRI.getVRegDef(Reg);
66+
if (!Shift || Shift->getOpcode() != RISCV::SLLI ||
67+
Shift->getParent() != &MBB || !MRI.hasOneNonDBGUse(Reg))
68+
return nullptr;
69+
70+
return Shift;
71+
}
72+
73+
bool RISCVOptZba::runOnMachineFunction(MachineFunction &MF) {
74+
if (skipFunction(MF.getFunction()))
75+
return false;
76+
77+
MachineRegisterInfo &MRI = MF.getRegInfo();
78+
const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
79+
const RISCVInstrInfo &TII = *ST.getInstrInfo();
80+
81+
if (!ST.hasStdExtZba())
82+
return false;
83+
84+
bool MadeChange = true;
85+
86+
for (MachineBasicBlock &MBB : MF) {
87+
for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
88+
unsigned OuterShiftAmt;
89+
switch (MI.getOpcode()) {
90+
default:
91+
continue;
92+
case RISCV::SH1ADD:
93+
OuterShiftAmt = 1;
94+
break;
95+
case RISCV::SH2ADD:
96+
OuterShiftAmt = 2;
97+
break;
98+
case RISCV::SH3ADD:
99+
OuterShiftAmt = 3;
100+
break;
101+
}
102+
103+
// Second operand must be virtual.
104+
Register UnshiftedReg = MI.getOperand(2).getReg();
105+
if (!UnshiftedReg.isVirtual())
106+
continue;
107+
108+
MachineInstr *Add = MRI.getVRegDef(UnshiftedReg);
109+
if (!Add || Add->getOpcode() != RISCV::ADD || Add->getParent() != &MBB ||
110+
!MRI.hasOneNonDBGUse(UnshiftedReg))
111+
continue;
112+
113+
Register AddReg0 = Add->getOperand(1).getReg();
114+
Register AddReg1 = Add->getOperand(2).getReg();
115+
116+
MachineInstr *InnerShift;
117+
Register X;
118+
if ((InnerShift = findShift(AddReg0, MBB, MRI)))
119+
X = AddReg1;
120+
else if ((InnerShift = findShift(AddReg1, MBB, MRI)))
121+
X = AddReg0;
122+
else
123+
continue;
124+
125+
unsigned InnerShiftAmt = InnerShift->getOperand(2).getImm();
126+
127+
// The inner shift amount must be at least as large as the outer shift
128+
// amount.
129+
if (OuterShiftAmt > InnerShiftAmt)
130+
continue;
131+
132+
unsigned InnerOpc;
133+
switch (InnerShiftAmt - OuterShiftAmt) {
134+
default:
135+
continue;
136+
case 0:
137+
InnerOpc = RISCV::ADD;
138+
break;
139+
case 1:
140+
InnerOpc = RISCV::SH1ADD;
141+
break;
142+
case 2:
143+
InnerOpc = RISCV::SH2ADD;
144+
break;
145+
case 3:
146+
InnerOpc = RISCV::SH3ADD;
147+
break;
148+
}
149+
150+
Register Y = InnerShift->getOperand(1).getReg();
151+
Register Z = MI.getOperand(1).getReg();
152+
153+
Register NewReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
154+
BuildMI(MBB, MI, MI.getDebugLoc(), TII.get(InnerOpc), NewReg)
155+
.addReg(Y)
156+
.addReg(Z);
157+
BuildMI(MBB, MI, MI.getDebugLoc(), TII.get(MI.getOpcode()),
158+
MI.getOperand(0).getReg())
159+
.addReg(NewReg)
160+
.addReg(X);
161+
162+
MI.eraseFromParent();
163+
Add->eraseFromParent();
164+
InnerShift->eraseFromParent();
165+
MadeChange = true;
166+
}
167+
}
168+
169+
return MadeChange;
170+
}

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
117117
initializeRISCVPostRAExpandPseudoPass(*PR);
118118
initializeRISCVMergeBaseOffsetOptPass(*PR);
119119
initializeRISCVOptWInstrsPass(*PR);
120+
initializeRISCVOptZbaPass(*PR);
120121
initializeRISCVPreRAExpandPseudoPass(*PR);
121122
initializeRISCVExpandPseudoPass(*PR);
122123
initializeRISCVFoldMasksPass(*PR);
@@ -531,6 +532,8 @@ void RISCVPassConfig::addMachineSSAOptimization() {
531532
if (EnableMachineCombiner)
532533
addPass(&MachineCombinerID);
533534

535+
addPass(createRISCVOptZbaPass());
536+
534537
if (TM->getTargetTriple().isRISCV64()) {
535538
addPass(createRISCVOptWInstrsPass());
536539
}

llvm/test/CodeGen/RISCV/O3-pipeline.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
; CHECK-NEXT: Machine Trace Metrics
113113
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
114114
; CHECK-NEXT: Machine InstCombiner
115+
; CHECK-NEXT: RISC-V Optimize Zba
115116
; RV64-NEXT: RISC-V Optimize W Instructions
116117
; CHECK-NEXT: RISC-V Pre-RA pseudo instruction expansion pass
117118
; CHECK-NEXT: RISC-V Merge Base Offset

llvm/test/CodeGen/RISCV/rv64zba.ll

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,9 +1404,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
14041404
;
14051405
; RV64ZBA-LABEL: sh6_sh3_add2:
14061406
; RV64ZBA: # %bb.0: # %entry
1407-
; RV64ZBA-NEXT: slli a1, a1, 6
1408-
; RV64ZBA-NEXT: add a0, a1, a0
1409-
; RV64ZBA-NEXT: sh3add a0, a2, a0
1407+
; RV64ZBA-NEXT: sh3add a1, a1, a2
1408+
; RV64ZBA-NEXT: sh3add a0, a1, a0
14101409
; RV64ZBA-NEXT: ret
14111410
entry:
14121411
%shl = shl i64 %z, 3
@@ -2111,9 +2110,8 @@ define i64 @array_index_sh1_sh3(ptr %p, i64 %idx1, i64 %idx2) {
21112110
;
21122111
; RV64ZBA-LABEL: array_index_sh1_sh3:
21132112
; RV64ZBA: # %bb.0:
2114-
; RV64ZBA-NEXT: slli a1, a1, 4
2115-
; RV64ZBA-NEXT: add a0, a0, a1
2116-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2113+
; RV64ZBA-NEXT: sh1add a1, a1, a2
2114+
; RV64ZBA-NEXT: sh3add a0, a1, a0
21172115
; RV64ZBA-NEXT: ld a0, 0(a0)
21182116
; RV64ZBA-NEXT: ret
21192117
%a = getelementptr inbounds [2 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2174,9 +2172,8 @@ define i32 @array_index_sh2_sh2(ptr %p, i64 %idx1, i64 %idx2) {
21742172
;
21752173
; RV64ZBA-LABEL: array_index_sh2_sh2:
21762174
; RV64ZBA: # %bb.0:
2177-
; RV64ZBA-NEXT: slli a1, a1, 4
2178-
; RV64ZBA-NEXT: add a0, a0, a1
2179-
; RV64ZBA-NEXT: sh2add a0, a2, a0
2175+
; RV64ZBA-NEXT: sh2add a1, a1, a2
2176+
; RV64ZBA-NEXT: sh2add a0, a1, a0
21802177
; RV64ZBA-NEXT: lw a0, 0(a0)
21812178
; RV64ZBA-NEXT: ret
21822179
%a = getelementptr inbounds [4 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2196,9 +2193,8 @@ define i64 @array_index_sh2_sh3(ptr %p, i64 %idx1, i64 %idx2) {
21962193
;
21972194
; RV64ZBA-LABEL: array_index_sh2_sh3:
21982195
; RV64ZBA: # %bb.0:
2199-
; RV64ZBA-NEXT: slli a1, a1, 5
2200-
; RV64ZBA-NEXT: add a0, a0, a1
2201-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2196+
; RV64ZBA-NEXT: sh2add a1, a1, a2
2197+
; RV64ZBA-NEXT: sh3add a0, a1, a0
22022198
; RV64ZBA-NEXT: ld a0, 0(a0)
22032199
; RV64ZBA-NEXT: ret
22042200
%a = getelementptr inbounds [4 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2238,9 +2234,8 @@ define i16 @array_index_sh3_sh1(ptr %p, i64 %idx1, i64 %idx2) {
22382234
;
22392235
; RV64ZBA-LABEL: array_index_sh3_sh1:
22402236
; RV64ZBA: # %bb.0:
2241-
; RV64ZBA-NEXT: slli a1, a1, 4
2242-
; RV64ZBA-NEXT: add a0, a0, a1
2243-
; RV64ZBA-NEXT: sh1add a0, a2, a0
2237+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2238+
; RV64ZBA-NEXT: sh1add a0, a1, a0
22442239
; RV64ZBA-NEXT: lh a0, 0(a0)
22452240
; RV64ZBA-NEXT: ret
22462241
%a = getelementptr inbounds [8 x i16], ptr %p, i64 %idx1, i64 %idx2
@@ -2260,9 +2255,8 @@ define i32 @array_index_sh3_sh2(ptr %p, i64 %idx1, i64 %idx2) {
22602255
;
22612256
; RV64ZBA-LABEL: array_index_sh3_sh2:
22622257
; RV64ZBA: # %bb.0:
2263-
; RV64ZBA-NEXT: slli a1, a1, 5
2264-
; RV64ZBA-NEXT: add a0, a0, a1
2265-
; RV64ZBA-NEXT: sh2add a0, a2, a0
2258+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2259+
; RV64ZBA-NEXT: sh2add a0, a1, a0
22662260
; RV64ZBA-NEXT: lw a0, 0(a0)
22672261
; RV64ZBA-NEXT: ret
22682262
%a = getelementptr inbounds [8 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2282,9 +2276,8 @@ define i64 @array_index_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
22822276
;
22832277
; RV64ZBA-LABEL: array_index_sh3_sh3:
22842278
; RV64ZBA: # %bb.0:
2285-
; RV64ZBA-NEXT: slli a1, a1, 6
2286-
; RV64ZBA-NEXT: add a0, a0, a1
2287-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2279+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2280+
; RV64ZBA-NEXT: sh3add a0, a1, a0
22882281
; RV64ZBA-NEXT: ld a0, 0(a0)
22892282
; RV64ZBA-NEXT: ret
22902283
%a = getelementptr inbounds [8 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2308,9 +2301,8 @@ define i64 @array_index_lshr_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
23082301
; RV64ZBA-LABEL: array_index_lshr_sh3_sh3:
23092302
; RV64ZBA: # %bb.0:
23102303
; RV64ZBA-NEXT: srli a1, a1, 58
2311-
; RV64ZBA-NEXT: slli a1, a1, 6
2312-
; RV64ZBA-NEXT: add a0, a0, a1
2313-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2304+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2305+
; RV64ZBA-NEXT: sh3add a0, a1, a0
23142306
; RV64ZBA-NEXT: ld a0, 0(a0)
23152307
; RV64ZBA-NEXT: ret
23162308
%shr = lshr i64 %idx1, 58

0 commit comments

Comments
 (0)