10
10
#include " DirectX.h"
11
11
#include " llvm/ADT/PostOrderIterator.h"
12
12
#include " llvm/ADT/STLExtras.h"
13
+ #include " llvm/IR/DerivedTypes.h"
13
14
#include " llvm/IR/GlobalVariable.h"
14
15
#include " llvm/IR/IRBuilder.h"
15
16
#include " llvm/IR/InstVisitor.h"
@@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M);
40
41
class DataScalarizerVisitor : public InstVisitor <DataScalarizerVisitor, bool > {
41
42
public:
42
43
DataScalarizerVisitor () : GlobalMap() {}
43
- bool visit (Instruction &I );
44
+ bool visit (Function &F );
44
45
// InstVisitor methods. They return true if the instruction was scalarized,
45
46
// false if nothing changed.
47
+ bool visitAllocaInst (AllocaInst &AI);
46
48
bool visitInstruction (Instruction &I) { return false ; }
47
49
bool visitSelectInst (SelectInst &SI) { return false ; }
48
50
bool visitICmpInst (ICmpInst &ICI) { return false ; }
@@ -67,9 +69,14 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
67
69
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
68
70
};
69
71
70
- bool DataScalarizerVisitor::visit (Instruction &I) {
71
- assert (!GlobalMap.empty ());
72
- return InstVisitor::visit (I);
72
+ bool DataScalarizerVisitor::visit (Function &F) {
73
+ bool MadeChange = false ;
74
+ ReversePostOrderTraversal<Function *> RPOT (&F);
75
+ for (BasicBlock *BB : make_early_inc_range (RPOT)) {
76
+ for (Instruction &I : make_early_inc_range (*BB))
77
+ MadeChange |= InstVisitor::visit (I);
78
+ }
79
+ return MadeChange;
73
80
}
74
81
75
82
GlobalVariable *
@@ -83,6 +90,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
83
90
return nullptr ; // Not found
84
91
}
85
92
93
+ // Recursively creates an array version of the given vector type.
94
+ static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
95
+ if (auto *VecTy = dyn_cast<VectorType>(T))
96
+ return ArrayType::get (VecTy->getElementType (),
97
+ dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
98
+ if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
99
+ Type *NewElementType =
100
+ replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
101
+ return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
102
+ }
103
+ // If it's not a vector or array, return the original type.
104
+ return T;
105
+ }
106
+
107
+ static bool isArrayOfVectors (Type *T) {
108
+ if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
109
+ return isa<VectorType>(ArrType->getElementType ());
110
+ return false ;
111
+ }
112
+
113
+ bool DataScalarizerVisitor::visitAllocaInst (AllocaInst &AI) {
114
+ if (!isArrayOfVectors (AI.getAllocatedType ()))
115
+ return false ;
116
+
117
+ ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType ());
118
+ IRBuilder<> Builder (&AI);
119
+ LLVMContext &Ctx = AI.getContext ();
120
+ Type *NewType = replaceVectorWithArray (ArrType, Ctx);
121
+ AllocaInst *ArrAlloca =
122
+ Builder.CreateAlloca (NewType, nullptr , AI.getName () + " .scalarize" );
123
+ ArrAlloca->setAlignment (AI.getAlign ());
124
+ AI.replaceAllUsesWith (ArrAlloca);
125
+ AI.eraseFromParent ();
126
+ return true ;
127
+ }
128
+
86
129
bool DataScalarizerVisitor::visitLoadInst (LoadInst &LI) {
87
130
unsigned NumOperands = LI.getNumOperands ();
88
131
for (unsigned I = 0 ; I < NumOperands; ++I) {
@@ -154,20 +197,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
154
197
return true ;
155
198
}
156
199
157
- // Recursively Creates and Array like version of the given vector like type.
158
- static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
159
- if (auto *VecTy = dyn_cast<VectorType>(T))
160
- return ArrayType::get (VecTy->getElementType (),
161
- dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
162
- if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
163
- Type *NewElementType =
164
- replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
165
- return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
166
- }
167
- // If it's not a vector or array, return the original type.
168
- return T;
169
- }
170
-
171
200
Constant *transformInitializer (Constant *Init, Type *OrigType, Type *NewType,
172
201
LLVMContext &Ctx) {
173
202
// Handle ConstantAggregateZero (zero-initialized constants)
@@ -253,20 +282,15 @@ static bool findAndReplaceVectors(Module &M) {
253
282
// Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
254
283
// type equality. Instead we will use the visitor pattern.
255
284
Impl.GlobalMap [&G] = NewGlobal;
256
- for (User *U : make_early_inc_range (G.users ())) {
257
- if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
258
- ConstantExpr *CE = cast<ConstantExpr>(U);
259
- for (User *UCE : make_early_inc_range (CE->users ())) {
260
- if (Instruction *Inst = dyn_cast<Instruction>(UCE))
261
- Impl.visit (*Inst);
262
- }
263
- }
264
- if (Instruction *Inst = dyn_cast<Instruction>(U))
265
- Impl.visit (*Inst);
266
- }
267
285
}
268
286
}
269
287
288
+ for (auto &F : make_early_inc_range (M.functions ())) {
289
+ if (F.isDeclaration ())
290
+ continue ;
291
+ MadeChange |= Impl.visit (F);
292
+ }
293
+
270
294
// Remove the old globals after the iteration
271
295
for (auto &[Old, New] : Impl.GlobalMap ) {
272
296
Old->eraseFromParent ();
0 commit comments