@@ -16,18 +16,22 @@ SPDX-License-Identifier: MIT
16
16
// /
17
17
// ===----------------------------------------------------------------------===//
18
18
19
-
20
19
#include " GenX.h"
20
+ #include " GenXSubtarget.h"
21
+ #include " GenXTargetMachine.h"
21
22
#include " GenXUtil.h"
23
+
22
24
#include " llvm/ADT/EquivalenceClasses.h"
23
25
#include " llvm/ADT/Statistic.h"
26
+ #include " llvm/CodeGen/TargetPassConfig.h"
24
27
#include " llvm/IR/IRBuilder.h"
25
28
#include " llvm/IR/InstIterator.h"
29
+ #include " llvm/InitializePasses.h"
26
30
#include " llvm/Pass.h"
27
31
28
32
#include " llvmWrapper/IR/DerivedTypes.h"
29
33
30
- #define DEBUG_TYPE " GENX_PROMOTE_PREDICATE "
34
+ #define DEBUG_TYPE " genx-promote-predicate "
31
35
32
36
using namespace llvm ;
33
37
using namespace genx ;
@@ -48,6 +52,7 @@ class GenXPromotePredicate : public FunctionPass {
48
52
bool runOnFunction (Function &F) override ;
49
53
// Returns the fixed name string this pass reports for itself.
StringRef getPassName() const override {
  return " GenXPromotePredicate";
}
50
54
// Declares the analyses this pass depends on and what it preserves.
// TargetPassConfig is required because runOnFunction fetches it with
// getAnalysis<TargetPassConfig>() to reach the GenX subtarget; the pass also
// reports that it preserves the CFG.
void getAnalysisUsage (AnalysisUsage &AU) const override {
55
+ AU.addRequired <TargetPassConfig>();
51
56
AU.setPreservesCFG ();
52
57
}
53
58
};
@@ -61,6 +66,7 @@ void initializeGenXPromotePredicatePass(PassRegistry &);
61
66
}
62
67
// Pass registration for GenXPromotePredicate. The BEGIN/END pair brackets the
// pass's declared dependencies; TargetPassConfig is listed because
// runOnFunction queries it via getAnalysis<TargetPassConfig>().
// NOTE(review): the two trailing `false` arguments are presumably LLVM's
// "CFG-only" and "is analysis" flags - confirm against the macro definition.
INITIALIZE_PASS_BEGIN (GenXPromotePredicate, " GenXPromotePredicate" ,
63
68
" GenXPromotePredicate" , false , false )
69
+ INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
64
70
INITIALIZE_PASS_END(GenXPromotePredicate, " GenXPromotePredicate" ,
65
71
" GenXPromotePredicate" , false , false )
66
72
@@ -138,8 +144,9 @@ static Value *promoteInstToScalar(Instruction *Inst) {
138
144
139
145
// Promote one predicate instruction to GRF - promote all its operands and
140
146
// the instruction itself, and then sink the result back to a predicate.
141
- static Value *promoteInst (Instruction *Inst) {
142
- if (auto *VTy = dyn_cast<IGCLLVM::FixedVectorType>(Inst->getType ())) {
147
+ static Value *promoteInst (Instruction *Inst, bool AllowScalarPromotion) {
148
+ if (auto *VTy = dyn_cast<IGCLLVM::FixedVectorType>(Inst->getType ());
149
+ VTy && AllowScalarPromotion) {
143
150
IGC_ASSERT (VTy->isIntOrIntVectorTy (1 ));
144
151
auto Width = VTy->getNumElements ();
145
152
@@ -220,7 +227,8 @@ static void foldBitcast(BitCastInst *Cast) {
220
227
class PredicateWeb {
221
228
public:
222
229
template <class InputIt >
223
- PredicateWeb (InputIt first, InputIt last) : Web(first, last) {}
230
+ PredicateWeb (InputIt First, InputIt Last, bool AllowScalar)
231
+ : Web(First, Last), AllowScalarPromotion(AllowScalar) {}
224
232
void print (llvm::raw_ostream &O) const {
225
233
for (auto Inst : Web)
226
234
O << *Inst << ' \n ' ;
@@ -236,7 +244,7 @@ class PredicateWeb {
236
244
// Do promotion.
237
245
SmallVector<Instruction *, 8 > Worklist;
238
246
for (auto *Inst : Web) {
239
- auto *PromotedInst = promoteInst (Inst);
247
+ auto *PromotedInst = promoteInst (Inst, AllowScalarPromotion );
240
248
241
249
if (isa<TruncInst>(PromotedInst) || isa<BitCastInst>(PromotedInst))
242
250
Worklist.push_back (cast<Instruction>(PromotedInst));
@@ -254,6 +262,7 @@ class PredicateWeb {
254
262
255
263
private:
256
264
SmallPtrSet<Instruction *, 16 > Web;
265
+ bool AllowScalarPromotion;
257
266
};
258
267
259
268
constexpr const char IdxMDName[] = " pred.index" ;
@@ -273,6 +282,11 @@ struct Comparator {
273
282
};
274
283
275
284
bool GenXPromotePredicate::runOnFunction (Function &F) {
285
+ auto &ST = getAnalysis<TargetPassConfig>()
286
+ .getTM <GenXTargetMachine>()
287
+ .getGenXSubtarget ();
288
+ bool AllowScalarPromotion = !ST.hasFusedEU ();
289
+
276
290
// Put every predicate instruction into its own equivalence class.
277
291
long Idx = 0 ;
278
292
llvm::EquivalenceClasses<Instruction *, Comparator> PredicateWebs;
@@ -303,7 +317,8 @@ bool GenXPromotePredicate::runOnFunction(Function &F) {
303
317
for (auto I = PredicateWebs.begin (), E = PredicateWebs.end (); I != E; ++I) {
304
318
if (!I->isLeader ())
305
319
continue ;
306
- PredicateWeb Web (PredicateWebs.member_begin (I), PredicateWebs.member_end ());
320
+ PredicateWeb Web (PredicateWebs.member_begin (I), PredicateWebs.member_end (),
321
+ AllowScalarPromotion);
307
322
LLVM_DEBUG (dbgs () << " Predicate web:\n " ; Web.dump ());
308
323
++NumCollectedPredicateWebs;
309
324
if (!Web.isBeneficialToPromote ())
0 commit comments