Commit 5cb4322

[AMDGPU] Filter candidates of LiveRegOptimizer for profitable cases
It is known that vectors whose elements fit in i16 will be split and scalarized in SelectionDAG's type legalizer (see SIISelLowering::getPreferredVectorAction). LRO attempts to undo that scalarization across basic-block boundaries and shoehorn the values into packed VGPRs. LRO is beneficial for operations that natively work on illegal vector types, because it avoids flip-flopping between unpacked and packed forms. If we know that the operations on a vector will be split and scalarized anyway, we do not want to shoehorn it back into a packed VGPR. Operations known to work natively on illegal vector types usually come in the form of intrinsics (MFMA, DOT8), buffer stores, shuffles, insert/extract element, and PHI nodes, to name a few.
1 parent 55ae118 commit 5cb4322
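
To make the heuristic concrete, below is a hypothetical LLVM IR sketch (an editor's illustration, not taken from this commit's tests; all function and value names are made up) of a candidate the filter should still accept: a <4 x i8> value is live from %entry into %use, and its consumer keeps the bytes packed by bitcasting to i32 and feeding a dot intrinsic, so coercing it into a packed VGPR across the block boundary is expected to pay off.

; Editor's illustration (assumed names, not part of the commit).
define amdgpu_kernel void @dot4_cross_block(ptr addrspace(1) %in,
                                            ptr addrspace(1) %out, i1 %cond) {
entry:
  %vec = load <4 x i8>, ptr addrspace(1) %in, align 4
  br i1 %cond, label %use, label %exit

use:                                              ; packed consumer in another block
  %packed = bitcast <4 x i8> %vec to i32
  %dot = call i32 @llvm.amdgcn.sdot4(i32 %packed, i32 %packed, i32 0, i1 false)
  store i32 %dot, ptr addrspace(1) %out, align 4
  br label %exit

exit:
  ret void
}

declare i32 @llvm.amdgcn.sdot4(i32, i32, i32, i1)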

7 files changed, +425 -169 lines changed

llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp

Lines changed: 220 additions & 7 deletions
@@ -14,12 +14,14 @@

 #include "AMDGPU.h"
 #include "AMDGPUTargetMachine.h"
+#include "AMDGPUTargetTransformInfo.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/UniformityAnalysis.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/IntrinsicsAMDGPU.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -45,6 +47,7 @@ class AMDGPULateCodeGenPrepare
   Function &F;
   const DataLayout &DL;
   const GCNSubtarget &ST;
+  const TargetTransformInfo &TTI;

   AssumptionCache *const AC;
   UniformityInfo &UA;
@@ -53,8 +56,9 @@ class AMDGPULateCodeGenPrepare

 public:
   AMDGPULateCodeGenPrepare(Function &F, const GCNSubtarget &ST,
-                           AssumptionCache *AC, UniformityInfo &UA)
-      : F(F), DL(F.getDataLayout()), ST(ST), AC(AC), UA(UA) {}
+                           const TargetTransformInfo &TTI, AssumptionCache *AC,
+                           UniformityInfo &UA)
+      : F(F), DL(F.getDataLayout()), ST(ST), TTI(TTI), AC(AC), UA(UA) {}
   bool run();
   bool visitInstruction(Instruction &) { return false; }

@@ -75,6 +79,8 @@ class LiveRegOptimizer {
   Module &Mod;
   const DataLayout &DL;
   const GCNSubtarget &ST;
+  const TargetTransformInfo &TTI;
+
   /// The scalar type to convert to
   Type *const ConvertToScalar;
   /// The set of visited Instructions
@@ -125,8 +131,210 @@ class LiveRegOptimizer {
     return LK.first != TargetLoweringBase::TypeLegal;
   }

-  LiveRegOptimizer(Module &Mod, const GCNSubtarget &ST)
-      : Mod(Mod), DL(Mod.getDataLayout()), ST(ST),
+  // Filtering based on operation or its cost.
+  // If an operation incurs high enough cost or natively work on
+  // vector of illegal type, ie. v2i8, then it makes sense to try
+  // to coerce them as packed VGPR across BB.
+  bool shouldReplaceByOp(Instruction *II) {
+    static const int SCALARIZE_INST_COST = 2;
+    static const int LRO_COST_THRES = 12;
+
+    // Ignore pseudos
+    if (II->isDebugOrPseudoInst())
+      return false;
+
+    // Instruction Cost
+    auto Cost = TTI.getInstructionCost(
+        II, TargetTransformInfo::TargetCostKind::TCK_SizeAndLatency);
+    if (const auto *Def = II->getOperand(0)) {
+      if (const auto *DefTy = dyn_cast<FixedVectorType>(Def->getType())) {
+        const auto *ElTy = dyn_cast<IntegerType>(DefTy->getElementType());
+        // Assume vNi8 and vNi16 will be scalarized.
+        if (ElTy && ElTy->getBitWidth() <= 16) {
+          const auto ElCount = DefTy->getElementCount().getFixedValue();
+          Cost += SCALARIZE_INST_COST * ElCount;
+        }
+      }
+    }
+    LLVM_DEBUG(dbgs() << "shouldReplaceByOp: " << *II << " Cost=" << Cost
+                      << '\n';);
+    if (Cost >= LRO_COST_THRES)
+      return true;
+
+    if (isOpLegal(II))
+      return true;
+
+    return false;
+  }
+
+  /// Check if intrinsic natively operates on 8-bit or 16-bit
+  bool isNativeIntrinsic(Intrinsic::ID ID) {
+    switch (ID) {
+    case Intrinsic::amdgcn_dot4_f32_fp8_bf8:
+    case Intrinsic::amdgcn_dot4_f32_bf8_fp8:
+    case Intrinsic::amdgcn_dot4_f32_fp8_fp8:
+    case Intrinsic::amdgcn_dot4_f32_bf8_bf8:
+    case Intrinsic::amdgcn_fdot2_f16_f16:
+    case Intrinsic::amdgcn_fdot2:
+    case Intrinsic::amdgcn_sdot4:
+    case Intrinsic::amdgcn_sdot2:
+    case Intrinsic::amdgcn_sdot8:
+    case Intrinsic::amdgcn_udot2:
+    case Intrinsic::amdgcn_udot4:
+    case Intrinsic::amdgcn_udot8:
+    case Intrinsic::amdgcn_sudot4:
+    case Intrinsic::amdgcn_sudot8:
+    case Intrinsic::amdgcn_mfma_f32_4x4x1f32:
+    case Intrinsic::amdgcn_mfma_f32_16x16x1f32:
+    case Intrinsic::amdgcn_mfma_f32_16x16x4f32:
+    case Intrinsic::amdgcn_mfma_f32_32x32x1f32:
+    case Intrinsic::amdgcn_mfma_f32_32x32x2f32:
+    case Intrinsic::amdgcn_mfma_f32_4x4x4f16:
+    case Intrinsic::amdgcn_mfma_i32_4x4x4i8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x4f16:
+    case Intrinsic::amdgcn_mfma_f32_16x16x16f16:
+    case Intrinsic::amdgcn_mfma_i32_16x16x4i8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x4f16:
+    case Intrinsic::amdgcn_mfma_f32_32x32x8f16:
+    case Intrinsic::amdgcn_mfma_i32_32x32x4i8:
+    case Intrinsic::amdgcn_mfma_i32_16x16x16i8:
+    case Intrinsic::amdgcn_mfma_i32_32x32x8i8:
+    case Intrinsic::amdgcn_mfma_f32_4x4x2bf16:
+    case Intrinsic::amdgcn_mfma_f32_16x16x2bf16:
+    case Intrinsic::amdgcn_mfma_f32_16x16x8bf16:
+    case Intrinsic::amdgcn_mfma_f32_32x32x2bf16:
+    case Intrinsic::amdgcn_mfma_f32_32x32x4bf16:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_f16:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_f16:
+    case Intrinsic::amdgcn_mfma_i32_16x16x64_i8:
+    case Intrinsic::amdgcn_mfma_i32_32x32x32_i8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x4bf16_1k:
+    case Intrinsic::amdgcn_mfma_f32_16x16x4bf16_1k:
+    case Intrinsic::amdgcn_mfma_f32_4x4x4bf16_1k:
+    case Intrinsic::amdgcn_mfma_f32_32x32x8bf16_1k:
+    case Intrinsic::amdgcn_mfma_f32_16x16x16bf16_1k:
+    case Intrinsic::amdgcn_mfma_f64_16x16x4f64:
+    case Intrinsic::amdgcn_mfma_f64_4x4x4f64:
+    case Intrinsic::amdgcn_mfma_i32_32x32x16_i8:
+    case Intrinsic::amdgcn_mfma_i32_16x16x32_i8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x8_xf32:
+    case Intrinsic::amdgcn_mfma_f32_32x32x4_xf32:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_bf8_bf8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_bf8_fp8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_fp8_bf8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_fp8_fp8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_bf8_bf8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_bf8_fp8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_fp8_bf8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_fp8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x32_f16:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x16_f16:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x32_bf16:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x16_bf16:
+    case Intrinsic::amdgcn_smfmac_i32_16x16x64_i8:
+    case Intrinsic::amdgcn_smfmac_i32_32x32x32_i8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_bf8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_bf8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_fp8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_fp8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_bf8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_bf8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_fp8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_fp8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_f16:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_f16:
+    case Intrinsic::amdgcn_smfmac_i32_16x16x128_i8:
+    case Intrinsic::amdgcn_smfmac_i32_32x32x64_i8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x128_bf8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x128_bf8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x128_fp8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x128_fp8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x64_bf8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x64_bf8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x64_fp8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x64_fp8_fp8:
+    case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
+    case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_f16:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_bf16:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8:
+    case Intrinsic::amdgcn_wmma_f16_16x16x16_f16:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_f16:
+    case Intrinsic::amdgcn_swmmac_f16_16x16x32_f16:
+    case Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16:
+    case Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8:
+    case Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied:
+    case Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied:
+    case Intrinsic::amdgcn_wmma_i32_16x16x16_iu8:
+    case Intrinsic::amdgcn_wmma_i32_16x16x16_iu4:
+    case Intrinsic::amdgcn_wmma_i32_16x16x32_iu4:
+    case Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8:
+    case Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4:
+    case Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4:
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  bool isOpLegal(Instruction *I) {
+    Type *T = I->getType();
+    if (!TTI.isTypeLegal(T)) {
+      if (const auto Intr = dyn_cast<IntrinsicInst>(I)) {
+        Intrinsic::ID ID = Intr->getIntrinsicID();
+        if (isNativeIntrinsic(ID))
+          return true;
+      }
+      // Stores
+      if (isa<StoreInst>(I))
+        return true;
+      return false;
+    }
+    return true;
+  }
+
+  bool isCoercionProfitable(Instruction *II) {
+    if (shouldReplaceByOp(II))
+      return true;
+
+    // Look through Users
+    bool Profitable = false;
+    SmallPtrSet<Instruction *, 4> CVisited;
+    SmallVector<Instruction *, 4> UserList;
+    for (User *V : II->users())
+      if (auto *UseInst = dyn_cast<Instruction>(V))
+        UserList.push_back(UseInst);
+
+    while (!UserList.empty() && !Profitable) {
+      auto CII = UserList.pop_back_val();
+      if (!CVisited.insert(II).second)
+        continue;
+
+      if (isa<PHINode>(CII) || isa<ShuffleVectorInst>(CII) ||
+          isa<InsertElementInst>(CII) || isa<ExtractElementInst>(CII))
+        for (User *V : CII->users())
+          if (auto *UseInst = dyn_cast<Instruction>(V))
+            UserList.push_back(UseInst);
+
+      if (CII->getParent() == II->getParent())
+        continue;
+
+      Profitable = shouldReplaceByOp(CII);
+    }
+    return Profitable;
+  }
+
+  LiveRegOptimizer(Module &Mod, const GCNSubtarget &ST,
+                   const TargetTransformInfo &TTI)
+      : Mod(Mod), DL(Mod.getDataLayout()), ST(ST), TTI(TTI),
         ConvertToScalar(Type::getInt32Ty(Mod.getContext())) {}
 };

@@ -140,7 +348,7 @@ bool AMDGPULateCodeGenPrepare::run() {
   // vectors to equivalent vectors of legal type (which are converted back
   // before uses in subsequent blocks), to pack the bits into fewer physical
   // registers (used in CopyToReg/CopyFromReg pairs).
-  LiveRegOptimizer LRO(*F.getParent(), ST);
+  LiveRegOptimizer LRO(*F.getParent(), ST, TTI);

   bool Changed = false;

@@ -259,6 +467,9 @@ bool LiveRegOptimizer::optimizeLiveType(
     if (!shouldReplace(II->getType()))
       continue;

+    if (!isCoercionProfitable(II))
+      continue;
+
     if (PHINode *Phi = dyn_cast<PHINode>(II)) {
       PhiNodes.insert(Phi);
       // Collect all the incoming values of problematic PHI nodes.
@@ -478,11 +689,12 @@ bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
 PreservedAnalyses
 AMDGPULateCodeGenPreparePass::run(Function &F, FunctionAnalysisManager &FAM) {
   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
+  const TargetTransformInfo &TTI = TM.getTargetTransformInfo(F);

   AssumptionCache &AC = FAM.getResult<AssumptionAnalysis>(F);
   UniformityInfo &UI = FAM.getResult<UniformityInfoAnalysis>(F);

-  bool Changed = AMDGPULateCodeGenPrepare(F, ST, &AC, UI).run();
+  bool Changed = AMDGPULateCodeGenPrepare(F, ST, TTI, &AC, UI).run();

   if (!Changed)
     return PreservedAnalyses::all();
@@ -518,13 +730,14 @@ bool AMDGPULateCodeGenPrepareLegacy::runOnFunction(Function &F) {
   const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
   const TargetMachine &TM = TPC.getTM<TargetMachine>();
   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
+  const TargetTransformInfo &TTI = TM.getTargetTransformInfo(F);

   AssumptionCache &AC =
       getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   UniformityInfo &UI =
       getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

-  return AMDGPULateCodeGenPrepare(F, ST, &AC, UI).run();
+  return AMDGPULateCodeGenPrepare(F, ST, TTI, &AC, UI).run();
 }

 INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepareLegacy, DEBUG_TYPE,
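
For contrast, here is a hypothetical counterpart (again an editor's sketch with made-up names, not from this commit) in which every cross-block use of the <4 x i8> value is an extractelement feeding scalar arithmetic. This is the split-and-scalarize pattern the commit message argues is not worth coercing back into a packed VGPR; whether the new cost threshold actually rejects such a candidate depends on the target's TTI cost numbers.

; Editor's illustration (assumed names, not part of the commit).
define amdgpu_kernel void @scalarized_cross_block(ptr addrspace(1) %in,
                                                  ptr addrspace(1) %out, i1 %cond) {
entry:
  %vec = load <4 x i8>, ptr addrspace(1) %in, align 4
  br i1 %cond, label %use, label %exit

use:                                              ; scalarized consumers in another block
  %e0 = extractelement <4 x i8> %vec, i64 0
  %e1 = extractelement <4 x i8> %vec, i64 1
  %sum = add i8 %e0, %e1
  store i8 %sum, ptr addrspace(1) %out, align 1
  br label %exit

exit:
  ret void
}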
