Skip to content

Commit 3f42c6b

Browse files
authored
[DirectX] Scalarize extractelement and insertelement with dynamic indices (#141676)
Fixes #141136 - Implement `visitExtractElementInst` and `visitInsertElementInst` in `DXILDataScalarizerVisitor` to scalarize `extractelement` and `insertelement` instructions whose index operand is not a `ConstantInt` by converting the vector to an array and then loading from the array - Rename the `replaceVectorWithArray` helper function to `equivalentArrayTypeFromVector`, relocate the function toward the top of the file, and remove the unused `Ctx` parameter
1 parent 521adc9 commit 3f42c6b

File tree

2 files changed

+331
-21
lines changed

2 files changed

+331
-21
lines changed

llvm/lib/Target/DirectX/DXILDataScalarization.cpp

Lines changed: 149 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ static const int MaxVecSize = 4;
2727

2828
using namespace llvm;
2929

30+
// Recursively creates an array-like version of a given vector type.
31+
static Type *equivalentArrayTypeFromVector(Type *T) {
32+
if (auto *VecTy = dyn_cast<VectorType>(T))
33+
return ArrayType::get(VecTy->getElementType(),
34+
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
35+
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
36+
Type *NewElementType =
37+
equivalentArrayTypeFromVector(ArrayTy->getElementType());
38+
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
39+
}
40+
// If it's not a vector or array, return the original type.
41+
return T;
42+
}
43+
3044
class DXILDataScalarizationLegacy : public ModulePass {
3145

3246
public:
@@ -54,8 +68,8 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
5468
bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
5569
bool visitCastInst(CastInst &CI) { return false; }
5670
bool visitBitCastInst(BitCastInst &BCI) { return false; }
57-
bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
58-
bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
71+
bool visitInsertElementInst(InsertElementInst &IEI);
72+
bool visitExtractElementInst(ExtractElementInst &EEI);
5973
bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
6074
bool visitPHINode(PHINode &PHI) { return false; }
6175
bool visitLoadInst(LoadInst &LI);
@@ -65,6 +79,16 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
6579
friend bool findAndReplaceVectors(llvm::Module &M);
6680

6781
private:
82+
typedef std::pair<AllocaInst *, SmallVector<Value *, 4>> AllocaAndGEPs;
83+
typedef SmallDenseMap<Value *, AllocaAndGEPs>
84+
VectorToArrayMap; // A map from a vector-typed Value to its corresponding
85+
// AllocaInst and GEPs to each element of an array
86+
VectorToArrayMap VectorAllocaMap;
87+
AllocaAndGEPs createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
88+
const Twine &Name);
89+
bool replaceDynamicInsertElementInst(InsertElementInst &IEI);
90+
bool replaceDynamicExtractElementInst(ExtractElementInst &EEI);
91+
6892
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
6993
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
7094
};
@@ -76,6 +100,7 @@ bool DataScalarizerVisitor::visit(Function &F) {
76100
for (Instruction &I : make_early_inc_range(*BB))
77101
MadeChange |= InstVisitor::visit(I);
78102
}
103+
VectorAllocaMap.clear();
79104
return MadeChange;
80105
}
81106

@@ -90,20 +115,6 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
90115
return nullptr; // Not found
91116
}
92117

93-
// Recursively creates an array version of the given vector type.
94-
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
95-
if (auto *VecTy = dyn_cast<VectorType>(T))
96-
return ArrayType::get(VecTy->getElementType(),
97-
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
98-
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
99-
Type *NewElementType =
100-
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
101-
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
102-
}
103-
// If it's not a vector or array, return the original type.
104-
return T;
105-
}
106-
107118
static bool isArrayOfVectors(Type *T) {
108119
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
109120
return isa<VectorType>(ArrType->getElementType());
@@ -116,8 +127,7 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
116127

117128
ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
118129
IRBuilder<> Builder(&AI);
119-
LLVMContext &Ctx = AI.getContext();
120-
Type *NewType = replaceVectorWithArray(ArrType, Ctx);
130+
Type *NewType = equivalentArrayTypeFromVector(ArrType);
121131
AllocaInst *ArrAlloca =
122132
Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
123133
ArrAlloca->setAlignment(AI.getAlign());
@@ -173,6 +183,124 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
173183
return false;
174184
}
175185

186+
/// Materializes the vector-typed value \p Vec as a stack array and returns the
/// alloca together with one constant-index GEP per element.
///
/// The result is cached in VectorAllocaMap, so repeated requests for the same
/// \p Vec return the previously created alloca and GEPs without emitting any
/// new instructions.
///
/// \param Builder Builder used to emit the instructions. Its insertion point
///                is moved while emitting and restored before returning.
/// \param Vec     The vector value to spill into an array.
/// \param Name    Prefix for the names of the created instructions.
/// \returns {alloca, GEPs}, where GEPs[I] addresses element I of the array.
DataScalarizerVisitor::AllocaAndGEPs
DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
                                             const Twine &Name = "") {
  // If there is already an alloca for this vector, return it. A single find()
  // avoids hashing the key twice (contains() followed by operator[]).
  auto Cached = VectorAllocaMap.find(Vec);
  if (Cached != VectorAllocaMap.end())
    return Cached->second;

  // Remember where the caller was emitting so it can be restored below.
  auto InsertPoint = Builder.GetInsertPoint();

  // Allocate the array to hold the vector elements
  Builder.SetInsertPointPastAllocas(Builder.GetInsertBlock()->getParent());
  Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType());
  AllocaInst *ArrAlloca =
      Builder.CreateAlloca(ArrTy, nullptr, Name + ".alloca");
  const uint64_t ArrNumElems = ArrTy->getArrayNumElements();

  // Create loads and stores to populate the array immediately after the
  // original vector's defining instruction if available, else immediately after
  // the alloca
  if (auto *Instr = dyn_cast<Instruction>(Vec)) {
    if (isa<PHINode>(Instr)) {
      // PHI nodes must stay grouped at the top of their block, so the populate
      // code cannot go directly after a PHI that is followed by another PHI.
      // Emit at the block's first valid insertion point instead.
      Builder.SetInsertPoint(Instr->getParent()->getFirstInsertionPt());
    } else {
      Builder.SetInsertPoint(Instr->getNextNonDebugInstruction());
    }
  }

  SmallVector<Value *, 4> GEPs(ArrNumElems);
  for (unsigned I = 0; I < ArrNumElems; ++I) {
    Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract");
    GEPs[I] = Builder.CreateInBoundsGEP(
        ArrTy, ArrAlloca, {Builder.getInt32(0), Builder.getInt32(I)},
        Name + ".index");
    Builder.CreateStore(EE, GEPs[I]);
  }

  VectorAllocaMap.insert({Vec, {ArrAlloca, GEPs}});
  Builder.SetInsertPoint(InsertPoint);
  return {ArrAlloca, GEPs};
}
220+
221+
/// Returns a pair of Value* with the first being a GEP into ArrAlloca using
222+
/// indices {0, Index}, and the second Value* being a Load of the GEP
223+
static std::pair<Value *, Value *>
224+
dynamicallyLoadArray(IRBuilder<> &Builder, AllocaInst *ArrAlloca, Value *Index,
225+
const Twine &Name = "") {
226+
Type *ArrTy = ArrAlloca->getAllocatedType();
227+
Value *GEP = Builder.CreateInBoundsGEP(
228+
ArrTy, ArrAlloca, {Builder.getInt32(0), Index}, Name + ".index");
229+
Value *Load =
230+
Builder.CreateLoad(ArrTy->getArrayElementType(), GEP, Name + ".load");
231+
return std::make_pair(GEP, Load);
232+
}
233+
234+
/// Rewrites an insertelement whose index is not a constant in terms of a stack
/// array: the new element is stored at the dynamic index, every element is
/// loaded back through its constant-index GEP to rebuild the result vector,
/// and the overwritten element is then restored in the array.
///
/// \returns true; \p IEI is always replaced and erased.
bool DataScalarizerVisitor::replaceDynamicInsertElementInst(
    InsertElementInst &IEI) {
  IRBuilder<> Builder(&IEI);

  Value *SrcVec = IEI.getOperand(0);
  Value *NewElem = IEI.getOperand(1);
  Value *DynIdx = IEI.getOperand(2);

  AllocaAndGEPs Backing = createArrayFromVector(Builder, SrcVec, IEI.getName());
  AllocaInst *ArrAlloca = Backing.first;
  SmallVector<Value *, 4> &LanePtrs = Backing.second;
  Type *BackingTy = ArrAlloca->getAllocatedType();

  // SavedElem holds what the array contained at DynIdx before the overwrite.
  auto [SlotPtr, SavedElem] =
      dynamicallyLoadArray(Builder, ArrAlloca, DynIdx, IEI.getName());

  Builder.CreateStore(NewElem, SlotPtr);

  // Reassemble the vector lane by lane from the updated array.
  Value *Rebuilt = PoisonValue::get(SrcVec->getType());
  for (unsigned Lane = 0; Lane < BackingTy->getArrayNumElements(); ++Lane) {
    Value *LaneVal = Builder.CreateLoad(BackingTy->getArrayElementType(),
                                        LanePtrs[Lane], IEI.getName() + ".load");
    Rebuilt = Builder.CreateInsertElement(Rebuilt, LaneVal,
                                          Builder.getInt32(Lane),
                                          IEI.getName() + ".insert");
  }

  // Store back the original value so the Alloca can be reused for subsequent
  // insertelement instructions on the same vector
  Builder.CreateStore(SavedElem, SlotPtr);

  IEI.replaceAllUsesWith(Rebuilt);
  IEI.eraseFromParent();
  return true;
}
270+
271+
/// Visitor hook for insertelement. Only instructions whose index operand is
/// not a ConstantInt are rewritten; the rest are left untouched.
///
/// \returns false when the index is a ConstantInt, otherwise the result of
///          replaceDynamicInsertElementInst.
bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  // Operand 2 of an insertelement is the index.
  const bool HasDynamicIndex = !isa<ConstantInt>(IEI.getOperand(2));
  if (HasDynamicIndex)
    return replaceDynamicInsertElementInst(IEI);

  // A constant index needs no scalarization.
  return false;
}
278+
279+
/// Rewrites an extractelement whose index is not a constant as a load from a
/// stack array holding the vector's elements, addressed by the dynamic index.
///
/// \returns true; \p EEI is always replaced and erased.
bool DataScalarizerVisitor::replaceDynamicExtractElementInst(
    ExtractElementInst &EEI) {
  IRBuilder<> Builder(&EEI);

  // Obtain (or reuse) the array that mirrors the vector operand.
  AllocaAndGEPs Backing =
      createArrayFromVector(Builder, EEI.getVectorOperand(), EEI.getName());

  // Only the loaded value is needed here; the GEP half of the pair is unused.
  Value *Extracted = dynamicallyLoadArray(Builder, Backing.first,
                                          EEI.getIndexOperand(), EEI.getName())
                         .second;

  EEI.replaceAllUsesWith(Extracted);
  EEI.eraseFromParent();
  return true;
}
295+
296+
/// Visitor hook for extractelement. Only instructions whose index operand is
/// not a ConstantInt are rewritten; the rest are left untouched.
///
/// \returns false when the index is a ConstantInt, otherwise the result of
///          replaceDynamicExtractElementInst.
bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  const bool HasDynamicIndex = !isa<ConstantInt>(EEI.getIndexOperand());
  if (HasDynamicIndex)
    return replaceDynamicExtractElementInst(EEI);

  // A constant index needs no scalarization.
  return false;
}
303+
176304
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
177305

178306
unsigned NumOperands = GEPI.getNumOperands();
@@ -197,8 +325,8 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
197325
return true;
198326
}
199327

200-
Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
201-
LLVMContext &Ctx) {
328+
static Constant *transformInitializer(Constant *Init, Type *OrigType,
329+
Type *NewType, LLVMContext &Ctx) {
202330
// Handle ConstantAggregateZero (zero-initialized constants)
203331
if (isa<ConstantAggregateZero>(Init)) {
204332
return ConstantAggregateZero::get(NewType);
@@ -257,7 +385,7 @@ static bool findAndReplaceVectors(Module &M) {
257385
for (GlobalVariable &G : M.globals()) {
258386
Type *OrigType = G.getValueType();
259387

260-
Type *NewType = replaceVectorWithArray(OrigType, Ctx);
388+
Type *NewType = equivalentArrayTypeFromVector(OrigType);
261389
if (OrigType != NewType) {
262390
// Create a new global variable with the updated type
263391
// Note: Initializer is set via transformInitializer

0 commit comments

Comments
 (0)