Skip to content

[DirectX] Scalarize extractelement and insertelement with dynamic indices #141676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 20, 2025
170 changes: 149 additions & 21 deletions llvm/lib/Target/DirectX/DXILDataScalarization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ static const int MaxVecSize = 4;

using namespace llvm;

// Recursively creates an array-like version of a given vector type.
static Type *equivalentArrayTypeFromVector(Type *T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine I would have just made a function declaration at the top so the implementation could live anywhere.

if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
Type *NewElementType =
equivalentArrayTypeFromVector(ArrayTy->getElementType());
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
}
// If it's not a vector or array, return the original type.
return T;
}

class DXILDataScalarizationLegacy : public ModulePass {

public:
Expand Down Expand Up @@ -54,8 +68,8 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
bool visitCastInst(CastInst &CI) { return false; }
bool visitBitCastInst(BitCastInst &BCI) { return false; }
bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
bool visitInsertElementInst(InsertElementInst &IEI);
bool visitExtractElementInst(ExtractElementInst &EEI);
bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
bool visitPHINode(PHINode &PHI) { return false; }
bool visitLoadInst(LoadInst &LI);
Expand All @@ -65,6 +79,16 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
friend bool findAndReplaceVectors(llvm::Module &M);

private:
typedef std::pair<AllocaInst *, SmallVector<Value *, 4>> AllocaAndGEPs;
typedef SmallDenseMap<Value *, AllocaAndGEPs>
VectorToArrayMap; // A map from a vector-typed Value to its corresponding
// AllocaInst and GEPs to each element of an array
VectorToArrayMap VectorAllocaMap;
AllocaAndGEPs createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
const Twine &Name);
bool replaceDynamicInsertElementInst(InsertElementInst &IEI);
bool replaceDynamicExtractElementInst(ExtractElementInst &EEI);

GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
};
Expand All @@ -76,6 +100,7 @@ bool DataScalarizerVisitor::visit(Function &F) {
for (Instruction &I : make_early_inc_range(*BB))
MadeChange |= InstVisitor::visit(I);
}
VectorAllocaMap.clear();
return MadeChange;
}

Expand All @@ -90,20 +115,6 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
return nullptr; // Not found
}

// Recursively creates an array version of the given vector type.
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
Type *NewElementType =
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
}
// If it's not a vector or array, return the original type.
return T;
}

static bool isArrayOfVectors(Type *T) {
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
return isa<VectorType>(ArrType->getElementType());
Expand All @@ -116,8 +127,7 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {

ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
IRBuilder<> Builder(&AI);
LLVMContext &Ctx = AI.getContext();
Type *NewType = replaceVectorWithArray(ArrType, Ctx);
Type *NewType = equivalentArrayTypeFromVector(ArrType);
AllocaInst *ArrAlloca =
Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
ArrAlloca->setAlignment(AI.getAlign());
Expand Down Expand Up @@ -173,6 +183,124 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
return false;
}

DataScalarizerVisitor::AllocaAndGEPs
DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
const Twine &Name = "") {
// If there is already an alloca for this vector, return it
if (VectorAllocaMap.contains(Vec))
return VectorAllocaMap[Vec];

auto InsertPoint = Builder.GetInsertPoint();

// Allocate the array to hold the vector elements
Builder.SetInsertPointPastAllocas(Builder.GetInsertBlock()->getParent());
Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType());
AllocaInst *ArrAlloca =
Builder.CreateAlloca(ArrTy, nullptr, Name + ".alloca");
const uint64_t ArrNumElems = ArrTy->getArrayNumElements();

// Create loads and stores to populate the array immediately after the
// original vector's defining instruction if available, else immediately after
// the alloca
if (auto *Instr = dyn_cast<Instruction>(Vec))
Builder.SetInsertPoint(Instr->getNextNonDebugInstruction());
SmallVector<Value *, 4> GEPs(ArrNumElems);
for (unsigned I = 0; I < ArrNumElems; ++I) {
Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract");
GEPs[I] = Builder.CreateInBoundsGEP(
ArrTy, ArrAlloca, {Builder.getInt32(0), Builder.getInt32(I)},
Name + ".index");
Builder.CreateStore(EE, GEPs[I]);
}

VectorAllocaMap.insert({Vec, {ArrAlloca, GEPs}});
Builder.SetInsertPoint(InsertPoint);
return {ArrAlloca, GEPs};
}

/// Returns a pair of Value* with the first being a GEP into ArrAlloca using
/// indices {0, Index}, and the second Value* being a Load of the GEP
static std::pair<Value *, Value *>
dynamicallyLoadArray(IRBuilder<> &Builder, AllocaInst *ArrAlloca, Value *Index,
const Twine &Name = "") {
Type *ArrTy = ArrAlloca->getAllocatedType();
Value *GEP = Builder.CreateInBoundsGEP(
ArrTy, ArrAlloca, {Builder.getInt32(0), Index}, Name + ".index");
Value *Load =
Builder.CreateLoad(ArrTy->getArrayElementType(), GEP, Name + ".load");
return std::make_pair(GEP, Load);
}

bool DataScalarizerVisitor::replaceDynamicInsertElementInst(
InsertElementInst &IEI) {
IRBuilder<> Builder(&IEI);

Value *Vec = IEI.getOperand(0);
Value *Val = IEI.getOperand(1);
Value *Index = IEI.getOperand(2);

AllocaAndGEPs ArrAllocaAndGEPs =
createArrayFromVector(Builder, Vec, IEI.getName());
AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;
Type *ArrTy = ArrAlloca->getAllocatedType();
SmallVector<Value *, 4> &ArrGEPs = ArrAllocaAndGEPs.second;

auto GEPAndLoad =
dynamicallyLoadArray(Builder, ArrAlloca, Index, IEI.getName());
Value *GEP = GEPAndLoad.first;
Value *Load = GEPAndLoad.second;

Builder.CreateStore(Val, GEP);
Value *NewIEI = PoisonValue::get(Vec->getType());
for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) {
Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), ArrGEPs[I],
IEI.getName() + ".load");
NewIEI = Builder.CreateInsertElement(NewIEI, Load, Builder.getInt32(I),
IEI.getName() + ".insert");
}

// Store back the original value so the Alloca can be reused for subsequent
// insertelement instructions on the same vector
Builder.CreateStore(Load, GEP);

IEI.replaceAllUsesWith(NewIEI);
IEI.eraseFromParent();
return true;
}

bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
// If the index is a constant then we don't need to scalarize it
Value *Index = IEI.getOperand(2);
if (isa<ConstantInt>(Index))
return false;
return replaceDynamicInsertElementInst(IEI);
}

bool DataScalarizerVisitor::replaceDynamicExtractElementInst(
ExtractElementInst &EEI) {
IRBuilder<> Builder(&EEI);

AllocaAndGEPs ArrAllocaAndGEPs =
createArrayFromVector(Builder, EEI.getVectorOperand(), EEI.getName());
AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;

auto GEPAndLoad = dynamicallyLoadArray(Builder, ArrAlloca,
EEI.getIndexOperand(), EEI.getName());
Value *Load = GEPAndLoad.second;

EEI.replaceAllUsesWith(Load);
EEI.eraseFromParent();
return true;
}

bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
// If the index is a constant then we don't need to scalarize it
Value *Index = EEI.getIndexOperand();
if (isa<ConstantInt>(Index))
return false;
return replaceDynamicExtractElementInst(EEI);
}

bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {

unsigned NumOperands = GEPI.getNumOperands();
Expand All @@ -197,8 +325,8 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
return true;
}

Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
LLVMContext &Ctx) {
static Constant *transformInitializer(Constant *Init, Type *OrigType,
Type *NewType, LLVMContext &Ctx) {
// Handle ConstantAggregateZero (zero-initialized constants)
if (isa<ConstantAggregateZero>(Init)) {
return ConstantAggregateZero::get(NewType);
Expand Down Expand Up @@ -257,7 +385,7 @@ static bool findAndReplaceVectors(Module &M) {
for (GlobalVariable &G : M.globals()) {
Type *OrigType = G.getValueType();

Type *NewType = replaceVectorWithArray(OrigType, Ctx);
Type *NewType = equivalentArrayTypeFromVector(OrigType);
if (OrigType != NewType) {
// Create a new global variable with the updated type
// Note: Initializer is set via transformInitializer
Expand Down
Loading
Loading