1
- // ===- DXILDataScalarization.cpp - Perform DXIL Data Legalization----===//
1
+ // ===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ----- ----===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
- // ===----------------------------------------------------------------===//
7
+ // ===--------------------------------------------------------------------- ===//
8
8
9
9
#include " DXILDataScalarization.h"
10
10
#include " DirectX.h"
11
11
#include " llvm/ADT/PostOrderIterator.h"
12
12
#include " llvm/ADT/STLExtras.h"
13
+ #include " llvm/Analysis/DXILResource.h"
13
14
#include " llvm/IR/GlobalVariable.h"
14
15
#include " llvm/IR/IRBuilder.h"
15
16
#include " llvm/IR/InstVisitor.h"
22
23
#include " llvm/Transforms/Utils/Local.h"
23
24
24
25
#define DEBUG_TYPE " dxil-data-scalarization"
25
- # define Max_VEC_SIZE 4
26
+ static const int MaxVecSize = 4 ;
26
27
27
28
using namespace llvm ;
28
29
29
- static void findAndReplaceVectors (Module &M);
30
+ class DXILDataScalarizationLegacy : public ModulePass {
31
+
32
+ public:
33
+ bool runOnModule (Module &M) override ;
34
+ DXILDataScalarizationLegacy () : ModulePass(ID) {}
35
+
36
+ void getAnalysisUsage (AnalysisUsage &AU) const override ;
37
+ static char ID; // Pass identification.
38
+ };
39
+
40
+ static bool findAndReplaceVectors (Module &M);
30
41
31
42
class DataScalarizerVisitor : public InstVisitor <DataScalarizerVisitor, bool > {
32
43
public:
@@ -51,10 +62,10 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
51
62
bool visitStoreInst (StoreInst &SI);
52
63
bool visitCallInst (CallInst &ICI) { return false ; }
53
64
bool visitFreezeInst (FreezeInst &FI) { return false ; }
54
- friend void findAndReplaceVectors (llvm::Module &M);
65
+ friend bool findAndReplaceVectors (llvm::Module &M);
55
66
56
67
private:
57
- GlobalVariable *getNewGlobalIfExists (Value *CurrOperand);
68
+ GlobalVariable *lookupReplacementGlobal (Value *CurrOperand);
58
69
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
59
70
SmallVector<WeakTrackingVH, 32 > PotentiallyDeadInstrs;
60
71
bool finish ();
@@ -81,7 +92,7 @@ bool DataScalarizerVisitor::finish() {
81
92
}
82
93
83
94
GlobalVariable *
84
- DataScalarizerVisitor::getNewGlobalIfExists (Value *CurrOperand) {
95
+ DataScalarizerVisitor::lookupReplacementGlobal (Value *CurrOperand) {
85
96
if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
86
97
auto It = GlobalMap.find (OldGlobal);
87
98
if (It != GlobalMap.end ()) {
@@ -92,43 +103,44 @@ DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
92
103
}
93
104
94
105
bool DataScalarizerVisitor::visitLoadInst (LoadInst &LI) {
95
- for (unsigned I = 0 ; I < LI.getNumOperands (); ++I) {
106
+ unsigned NumOperands = LI.getNumOperands ();
107
+ for (unsigned I = 0 ; I < NumOperands; ++I) {
96
108
Value *CurrOpperand = LI.getOperand (I);
97
- GlobalVariable *NewGlobal = getNewGlobalIfExists (CurrOpperand);
98
- if (NewGlobal)
109
+ if (GlobalVariable *NewGlobal = lookupReplacementGlobal (CurrOpperand))
99
110
LI.setOperand (I, NewGlobal);
100
111
}
101
112
return false ;
102
113
}
103
114
104
115
bool DataScalarizerVisitor::visitStoreInst (StoreInst &SI) {
105
- for (unsigned I = 0 ; I < SI.getNumOperands (); ++I) {
116
+ unsigned NumOperands = SI.getNumOperands ();
117
+ for (unsigned I = 0 ; I < NumOperands; ++I) {
106
118
Value *CurrOpperand = SI.getOperand (I);
107
- GlobalVariable *NewGlobal = getNewGlobalIfExists (CurrOpperand);
108
- if (NewGlobal) {
119
+ if (GlobalVariable *NewGlobal = lookupReplacementGlobal (CurrOpperand)) {
109
120
SI.setOperand (I, NewGlobal);
110
121
}
111
122
}
112
123
return false ;
113
124
}
114
125
115
126
bool DataScalarizerVisitor::visitGetElementPtrInst (GetElementPtrInst &GEPI) {
116
- for (unsigned I = 0 ; I < GEPI.getNumOperands (); ++I) {
127
+ unsigned NumOperands = GEPI.getNumOperands ();
128
+ for (unsigned I = 0 ; I < NumOperands; ++I) {
117
129
Value *CurrOpperand = GEPI.getOperand (I);
118
- GlobalVariable *NewGlobal = getNewGlobalIfExists (CurrOpperand);
119
- if (NewGlobal) {
120
- IRBuilder<> Builder (&GEPI);
130
+ GlobalVariable *NewGlobal = lookupReplacementGlobal (CurrOpperand);
131
+ if (!NewGlobal)
132
+ continue ;
133
+ IRBuilder<> Builder (&GEPI);
121
134
122
- SmallVector<Value *, Max_VEC_SIZE > Indices;
123
- for (auto &Index : GEPI.indices ())
124
- Indices.push_back (Index);
135
+ SmallVector<Value *, MaxVecSize > Indices;
136
+ for (auto &Index : GEPI.indices ())
137
+ Indices.push_back (Index);
125
138
126
- Value *NewGEP =
127
- Builder.CreateGEP (NewGlobal->getValueType (), NewGlobal, Indices);
139
+ Value *NewGEP =
140
+ Builder.CreateGEP (NewGlobal->getValueType (), NewGlobal, Indices);
128
141
129
- GEPI.replaceAllUsesWith (NewGEP);
130
- PotentiallyDeadInstrs.emplace_back (&GEPI);
131
- }
142
+ GEPI.replaceAllUsesWith (NewGEP);
143
+ PotentiallyDeadInstrs.emplace_back (&GEPI);
132
144
}
133
145
return true ;
134
146
}
@@ -137,7 +149,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
137
149
static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
138
150
if (auto *VecTy = dyn_cast<VectorType>(T))
139
151
return ArrayType::get (VecTy->getElementType (),
140
- cast <FixedVectorType>(VecTy)->getNumElements ());
152
+ dyn_cast <FixedVectorType>(VecTy)->getNumElements ());
141
153
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
142
154
Type *NewElementType =
143
155
replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
@@ -162,7 +174,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
162
174
// Handle vector to array transformation
163
175
if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
164
176
// Convert vector initializer to array initializer
165
- SmallVector<Constant *, Max_VEC_SIZE > ArrayElements;
177
+ SmallVector<Constant *, MaxVecSize > ArrayElements;
166
178
if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
167
179
for (unsigned I = 0 ; I < ConstVecInit->getNumOperands (); ++I)
168
180
ArrayElements.push_back (ConstVecInit->getOperand (I));
@@ -171,22 +183,19 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
171
183
for (unsigned I = 0 ; I < ConstDataVecInit->getNumElements (); ++I)
172
184
ArrayElements.push_back (ConstDataVecInit->getElementAsConstant (I));
173
185
} else {
174
- llvm_unreachable ( " Expected a ConstantVector or ConstantDataVector for "
175
- " vector initializer!" );
186
+ assert ( false && " Expected a ConstantVector or ConstantDataVector for "
187
+ " vector initializer!" );
176
188
}
177
189
178
190
return ConstantArray::get (cast<ArrayType>(NewType), ArrayElements);
179
191
}
180
192
181
193
// Handle array of vectors transformation
182
194
if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
183
-
184
195
auto *ArrayInit = dyn_cast<ConstantArray>(Init);
185
- if (!ArrayInit) {
186
- llvm_unreachable (" Expected a ConstantArray for array initializer!" );
187
- }
196
+ assert (ArrayInit && " Expected a ConstantArray for array initializer!" );
188
197
189
- SmallVector<Constant *, Max_VEC_SIZE > NewArrayElements;
198
+ SmallVector<Constant *, MaxVecSize > NewArrayElements;
190
199
for (unsigned I = 0 ; I < ArrayTy->getNumElements (); ++I) {
191
200
// Recursively transform array elements
192
201
Constant *NewElemInit = transformInitializer (
@@ -202,7 +211,8 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
202
211
return Init;
203
212
}
204
213
205
- static void findAndReplaceVectors (Module &M) {
214
+ static bool findAndReplaceVectors (Module &M) {
215
+ bool MadeChange = false ;
206
216
LLVMContext &Ctx = M.getContext ();
207
217
IRBuilder<> Builder (Ctx);
208
218
DataScalarizerVisitor Impl;
@@ -212,17 +222,17 @@ static void findAndReplaceVectors(Module &M) {
212
222
Type *NewType = replaceVectorWithArray (OrigType, Ctx);
213
223
if (OrigType != NewType) {
214
224
// Create a new global variable with the updated type
225
+ // Note: Initializer is set via transformInitializer
215
226
GlobalVariable *NewGlobal = new GlobalVariable (
216
227
M, NewType, G.isConstant (), G.getLinkage (),
217
- // Initializer is set via transformInitializer
218
228
/* Initializer=*/ nullptr , G.getName () + " .scalarized" , &G,
219
229
G.getThreadLocalMode (), G.getAddressSpace (),
220
230
G.isExternallyInitialized ());
221
231
222
232
// Copy relevant attributes
223
233
NewGlobal->setUnnamedAddr (G.getUnnamedAddr ());
224
234
if (G.getAlignment () > 0 ) {
225
- NewGlobal->setAlignment (Align (G. getAlignment () ));
235
+ NewGlobal->setAlignment (G. getAlign ( ));
226
236
}
227
237
228
238
if (G.hasInitializer ()) {
@@ -253,24 +263,30 @@ static void findAndReplaceVectors(Module &M) {
253
263
}
254
264
255
265
// Remove the old globals after the iteration
256
- for (auto Pair : Impl.GlobalMap ) {
257
- GlobalVariable *OldG = Pair. getFirst ();
258
- OldG-> eraseFromParent () ;
266
+ for (auto &[Old, New] : Impl.GlobalMap ) {
267
+ Old-> eraseFromParent ();
268
+ MadeChange = true ;
259
269
}
270
+ return MadeChange;
260
271
}
261
272
262
273
PreservedAnalyses DXILDataScalarization::run (Module &M,
263
274
ModuleAnalysisManager &) {
264
- findAndReplaceVectors (M);
265
- return PreservedAnalyses::none ();
275
+ bool MadeChanges = findAndReplaceVectors (M);
276
+ if (!MadeChanges)
277
+ return PreservedAnalyses::all ();
278
+ PreservedAnalyses PA;
279
+ PA.preserve <DXILResourceAnalysis>();
280
+ return PA;
266
281
}
267
282
268
283
bool DXILDataScalarizationLegacy::runOnModule (Module &M) {
269
- findAndReplaceVectors (M);
270
- return true ;
284
+ return findAndReplaceVectors (M);
271
285
}
272
286
273
- void DXILDataScalarizationLegacy::getAnalysisUsage (AnalysisUsage &AU) const {}
287
+ void DXILDataScalarizationLegacy::getAnalysisUsage (AnalysisUsage &AU) const {
288
+ AU.addPreserved <DXILResourceWrapperPass>();
289
+ }
274
290
275
291
char DXILDataScalarizationLegacy::ID = 0 ;
276
292
0 commit comments