Skip to content

Commit 695bd0a

Browse files
committed
[RISCV] Add a MIR pass to reassociate shXadd, add, and slli to form more shXadd.
This reassociates patterns like (sh3add Z, (add X, (slli Y, 6))) into (sh3add (sh3add Y, Z), X). This improves a pattern that occurs in 531.deepsjeng_r. Reducing the dynamic instruction count by 0.5%. This may be possible to improve in SelectionDAG, but given the special cases around shXadd formation, it's not obvious it can be done in a robust way without adding multiple special cases. I've used a GEP with 2 indices because that mostly closely resembles the motivating case. Most of the test cases are the simplest GEP case. One test has a logical right shift on an index which is closer to the deepsjeng code. This requires special handling in isel to reverse a DAGCombiner canonicalization that turns a pair of shifts into (srl (and X, C1), C2). See also llvm#85734 which had a hacky version of a similar optimization.
1 parent eec4299 commit 695bd0a

File tree

6 files changed

+174
-24
lines changed

6 files changed

+174
-24
lines changed

llvm/lib/Target/RISCV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ add_llvm_target(RISCVCodeGen
4646
RISCVMachineFunctionInfo.cpp
4747
RISCVMergeBaseOffset.cpp
4848
RISCVOptWInstrs.cpp
49+
RISCVOptZba.cpp
4950
RISCVPostRAExpandPseudoInsts.cpp
5051
RISCVRedundantCopyElimination.cpp
5152
RISCVMoveMerger.cpp

llvm/lib/Target/RISCV/RISCV.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ void initializeRISCVFoldMasksPass(PassRegistry &);
4646
FunctionPass *createRISCVOptWInstrsPass();
4747
void initializeRISCVOptWInstrsPass(PassRegistry &);
4848

49+
FunctionPass *createRISCVOptZbaPass();
50+
void initializeRISCVOptZbaPass(PassRegistry &);
51+
4952
FunctionPass *createRISCVMergeBaseOffsetOptPass();
5053
void initializeRISCVMergeBaseOffsetOptPass(PassRegistry &);
5154

llvm/lib/Target/RISCV/RISCVOptZba.cpp

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
//
9+
// This pass attempts to reassociate expressions like
10+
// (sh3add Z, (add X, (slli Y, 6)))
11+
// To
12+
// (sh3add (sh3add Y, Z), X)
13+
//
14+
// This reduces number of instructions needing by spreading the shift amount
15+
// across to 2 sh3adds. This can be generalized to sh1add/sh2add and other
16+
// shift amounts.
17+
//
18+
// TODO: We can also support slli.uw by using shXadd.uw or the inner shXadd.
19+
20+
#include "RISCV.h"
21+
#include "RISCVSubtarget.h"
22+
#include "llvm/CodeGen/MachineFunctionPass.h"
23+
24+
using namespace llvm;
25+
26+
#define DEBUG_TYPE "riscv-opt-zba"
27+
#define RISCV_OPT_ZBA_NAME "RISC-V Optimize Zba"
28+
29+
namespace {
30+
31+
class RISCVOptZba : public MachineFunctionPass {
32+
public:
33+
static char ID;
34+
35+
RISCVOptZba() : MachineFunctionPass(ID) {}
36+
37+
bool runOnMachineFunction(MachineFunction &MF) override;
38+
39+
void getAnalysisUsage(AnalysisUsage &AU) const override {
40+
AU.setPreservesCFG();
41+
MachineFunctionPass::getAnalysisUsage(AU);
42+
}
43+
44+
StringRef getPassName() const override { return RISCV_OPT_ZBA_NAME; }
45+
};
46+
47+
} // end anonymous namespace
48+
49+
char RISCVOptZba::ID = 0;
50+
INITIALIZE_PASS(RISCVOptZba, DEBUG_TYPE, RISCV_OPT_ZBA_NAME, false, false)
51+
52+
FunctionPass *llvm::createRISCVOptZbaPass() { return new RISCVOptZba(); }
53+
54+
static bool findShift(Register Reg, const MachineBasicBlock &MBB,
55+
MachineInstr *&Shift, const MachineRegisterInfo &MRI) {
56+
if (!Reg.isVirtual())
57+
return false;
58+
59+
Shift = MRI.getVRegDef(Reg);
60+
if (!Shift || Shift->getOpcode() != RISCV::SLLI ||
61+
Shift->getParent() != &MBB || !MRI.hasOneNonDBGUse(Reg))
62+
return false;
63+
64+
return true;
65+
}
66+
67+
bool RISCVOptZba::runOnMachineFunction(MachineFunction &MF) {
68+
if (skipFunction(MF.getFunction()))
69+
return false;
70+
71+
MachineRegisterInfo &MRI = MF.getRegInfo();
72+
const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
73+
const RISCVInstrInfo &TII = *ST.getInstrInfo();
74+
75+
if (!ST.hasStdExtZba())
76+
return false;
77+
78+
bool MadeChange = true;
79+
80+
for (MachineBasicBlock &MBB : MF) {
81+
for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
82+
unsigned OuterShiftAmt;
83+
switch (MI.getOpcode()) {
84+
default:
85+
continue;
86+
case RISCV::SH1ADD: OuterShiftAmt = 1; break;
87+
case RISCV::SH2ADD: OuterShiftAmt = 2; break;
88+
case RISCV::SH3ADD: OuterShiftAmt = 3; break;
89+
}
90+
91+
// Second operand must be virtual.
92+
Register UnshiftedReg = MI.getOperand(2).getReg();
93+
if (!UnshiftedReg.isVirtual())
94+
continue;
95+
96+
MachineInstr *Add = MRI.getVRegDef(UnshiftedReg);
97+
if (!Add || Add->getOpcode() != RISCV::ADD || Add->getParent() != &MBB ||
98+
!MRI.hasOneNonDBGUse(UnshiftedReg))
99+
continue;
100+
101+
Register AddReg0 = Add->getOperand(1).getReg();
102+
Register AddReg1 = Add->getOperand(2).getReg();
103+
104+
MachineInstr *InnerShift;
105+
Register X;
106+
if (findShift(AddReg0, MBB, InnerShift, MRI))
107+
X = AddReg1;
108+
else if (findShift(AddReg1, MBB, InnerShift, MRI))
109+
X = AddReg0;
110+
else
111+
continue;
112+
113+
unsigned InnerShiftAmt = InnerShift->getOperand(2).getImm();
114+
115+
// The inner shift amount must be at least as large as the outer shift
116+
// amount.
117+
if (OuterShiftAmt > InnerShiftAmt)
118+
continue;
119+
120+
unsigned InnerOpc;
121+
switch (InnerShiftAmt - OuterShiftAmt) {
122+
default:
123+
continue;
124+
case 0: InnerOpc = RISCV::ADD; break;
125+
case 1: InnerOpc = RISCV::SH1ADD; break;
126+
case 2: InnerOpc = RISCV::SH2ADD; break;
127+
case 3: InnerOpc = RISCV::SH3ADD; break;
128+
}
129+
130+
Register Y = InnerShift->getOperand(1).getReg();
131+
Register Z = MI.getOperand(1).getReg();
132+
133+
Register NewReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
134+
BuildMI(MBB, MI, MI.getDebugLoc(), TII.get(InnerOpc), NewReg)
135+
.addReg(Y)
136+
.addReg(Z);
137+
BuildMI(MBB, MI, MI.getDebugLoc(), TII.get(MI.getOpcode()),
138+
MI.getOperand(0).getReg())
139+
.addReg(NewReg)
140+
.addReg(X);
141+
142+
MI.eraseFromParent();
143+
Add->eraseFromParent();
144+
InnerShift->eraseFromParent();
145+
MadeChange = true;
146+
}
147+
}
148+
149+
return MadeChange;
150+
}

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
117117
initializeRISCVPostRAExpandPseudoPass(*PR);
118118
initializeRISCVMergeBaseOffsetOptPass(*PR);
119119
initializeRISCVOptWInstrsPass(*PR);
120+
initializeRISCVOptZbaPass(*PR);
120121
initializeRISCVPreRAExpandPseudoPass(*PR);
121122
initializeRISCVExpandPseudoPass(*PR);
122123
initializeRISCVFoldMasksPass(*PR);
@@ -531,6 +532,8 @@ void RISCVPassConfig::addMachineSSAOptimization() {
531532
if (EnableMachineCombiner)
532533
addPass(&MachineCombinerID);
533534

535+
addPass(createRISCVOptZbaPass());
536+
534537
if (TM->getTargetTriple().isRISCV64()) {
535538
addPass(createRISCVOptWInstrsPass());
536539
}

llvm/test/CodeGen/RISCV/O3-pipeline.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
; CHECK-NEXT: Machine Trace Metrics
113113
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
114114
; CHECK-NEXT: Machine InstCombiner
115+
; CHECK-NEXT: RISC-V Optimize Zba
115116
; RV64-NEXT: RISC-V Optimize W Instructions
116117
; CHECK-NEXT: RISC-V Pre-RA pseudo instruction expansion pass
117118
; CHECK-NEXT: RISC-V Merge Base Offset

llvm/test/CodeGen/RISCV/rv64zba.ll

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,9 +1404,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
14041404
;
14051405
; RV64ZBA-LABEL: sh6_sh3_add2:
14061406
; RV64ZBA: # %bb.0: # %entry
1407-
; RV64ZBA-NEXT: slli a1, a1, 6
1408-
; RV64ZBA-NEXT: add a0, a1, a0
1409-
; RV64ZBA-NEXT: sh3add a0, a2, a0
1407+
; RV64ZBA-NEXT: sh3add a1, a1, a2
1408+
; RV64ZBA-NEXT: sh3add a0, a1, a0
14101409
; RV64ZBA-NEXT: ret
14111410
entry:
14121411
%shl = shl i64 %z, 3
@@ -2111,9 +2110,8 @@ define i64 @array_index_sh1_sh3(ptr %p, i64 %idx1, i64 %idx2) {
21112110
;
21122111
; RV64ZBA-LABEL: array_index_sh1_sh3:
21132112
; RV64ZBA: # %bb.0:
2114-
; RV64ZBA-NEXT: slli a1, a1, 4
2115-
; RV64ZBA-NEXT: add a0, a0, a1
2116-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2113+
; RV64ZBA-NEXT: sh1add a1, a1, a2
2114+
; RV64ZBA-NEXT: sh3add a0, a1, a0
21172115
; RV64ZBA-NEXT: ld a0, 0(a0)
21182116
; RV64ZBA-NEXT: ret
21192117
%a = getelementptr inbounds [2 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2174,9 +2172,8 @@ define i32 @array_index_sh2_sh2(ptr %p, i64 %idx1, i64 %idx2) {
21742172
;
21752173
; RV64ZBA-LABEL: array_index_sh2_sh2:
21762174
; RV64ZBA: # %bb.0:
2177-
; RV64ZBA-NEXT: slli a1, a1, 4
2178-
; RV64ZBA-NEXT: add a0, a0, a1
2179-
; RV64ZBA-NEXT: sh2add a0, a2, a0
2175+
; RV64ZBA-NEXT: sh2add a1, a1, a2
2176+
; RV64ZBA-NEXT: sh2add a0, a1, a0
21802177
; RV64ZBA-NEXT: lw a0, 0(a0)
21812178
; RV64ZBA-NEXT: ret
21822179
%a = getelementptr inbounds [4 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2196,9 +2193,8 @@ define i64 @array_index_sh2_sh3(ptr %p, i64 %idx1, i64 %idx2) {
21962193
;
21972194
; RV64ZBA-LABEL: array_index_sh2_sh3:
21982195
; RV64ZBA: # %bb.0:
2199-
; RV64ZBA-NEXT: slli a1, a1, 5
2200-
; RV64ZBA-NEXT: add a0, a0, a1
2201-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2196+
; RV64ZBA-NEXT: sh2add a1, a1, a2
2197+
; RV64ZBA-NEXT: sh3add a0, a1, a0
22022198
; RV64ZBA-NEXT: ld a0, 0(a0)
22032199
; RV64ZBA-NEXT: ret
22042200
%a = getelementptr inbounds [4 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2238,9 +2234,8 @@ define i16 @array_index_sh3_sh1(ptr %p, i64 %idx1, i64 %idx2) {
22382234
;
22392235
; RV64ZBA-LABEL: array_index_sh3_sh1:
22402236
; RV64ZBA: # %bb.0:
2241-
; RV64ZBA-NEXT: slli a1, a1, 4
2242-
; RV64ZBA-NEXT: add a0, a0, a1
2243-
; RV64ZBA-NEXT: sh1add a0, a2, a0
2237+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2238+
; RV64ZBA-NEXT: sh1add a0, a1, a0
22442239
; RV64ZBA-NEXT: lh a0, 0(a0)
22452240
; RV64ZBA-NEXT: ret
22462241
%a = getelementptr inbounds [8 x i16], ptr %p, i64 %idx1, i64 %idx2
@@ -2260,9 +2255,8 @@ define i32 @array_index_sh3_sh2(ptr %p, i64 %idx1, i64 %idx2) {
22602255
;
22612256
; RV64ZBA-LABEL: array_index_sh3_sh2:
22622257
; RV64ZBA: # %bb.0:
2263-
; RV64ZBA-NEXT: slli a1, a1, 5
2264-
; RV64ZBA-NEXT: add a0, a0, a1
2265-
; RV64ZBA-NEXT: sh2add a0, a2, a0
2258+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2259+
; RV64ZBA-NEXT: sh2add a0, a1, a0
22662260
; RV64ZBA-NEXT: lw a0, 0(a0)
22672261
; RV64ZBA-NEXT: ret
22682262
%a = getelementptr inbounds [8 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2282,9 +2276,8 @@ define i64 @array_index_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
22822276
;
22832277
; RV64ZBA-LABEL: array_index_sh3_sh3:
22842278
; RV64ZBA: # %bb.0:
2285-
; RV64ZBA-NEXT: slli a1, a1, 6
2286-
; RV64ZBA-NEXT: add a0, a0, a1
2287-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2279+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2280+
; RV64ZBA-NEXT: sh3add a0, a1, a0
22882281
; RV64ZBA-NEXT: ld a0, 0(a0)
22892282
; RV64ZBA-NEXT: ret
22902283
%a = getelementptr inbounds [8 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2308,9 +2301,8 @@ define i64 @array_index_lshr_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
23082301
; RV64ZBA-LABEL: array_index_lshr_sh3_sh3:
23092302
; RV64ZBA: # %bb.0:
23102303
; RV64ZBA-NEXT: srli a1, a1, 58
2311-
; RV64ZBA-NEXT: slli a1, a1, 6
2312-
; RV64ZBA-NEXT: add a0, a0, a1
2313-
; RV64ZBA-NEXT: sh3add a0, a2, a0
2304+
; RV64ZBA-NEXT: sh3add a1, a1, a2
2305+
; RV64ZBA-NEXT: sh3add a0, a1, a0
23142306
; RV64ZBA-NEXT: ld a0, 0(a0)
23152307
; RV64ZBA-NEXT: ret
23162308
%shr = lshr i64 %idx1, 58

0 commit comments

Comments
 (0)