Skip to content

Commit 909212f

Browse files
authored
[DirectX] Scalarize Allocas as part of data scalarization (#140165)
- DXILDataScalarization should not be limited to global data.
- Add scalarization for allocas.
- Add a ReversePostOrderTraversal of each function that iterates over its basic blocks and runs DataScalarizerVisitor on every instruction.
- Fixes #140143.
1 parent 6a738f6 commit 909212f

File tree

3 files changed

+75
-35
lines changed

3 files changed

+75
-35
lines changed

llvm/lib/Target/DirectX/DXILDataScalarization.cpp

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "DirectX.h"
1111
#include "llvm/ADT/PostOrderIterator.h"
1212
#include "llvm/ADT/STLExtras.h"
13+
#include "llvm/IR/DerivedTypes.h"
1314
#include "llvm/IR/GlobalVariable.h"
1415
#include "llvm/IR/IRBuilder.h"
1516
#include "llvm/IR/InstVisitor.h"
@@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M);
4041
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
4142
public:
4243
DataScalarizerVisitor() : GlobalMap() {}
43-
bool visit(Instruction &I);
44+
bool visit(Function &F);
4445
// InstVisitor methods. They return true if the instruction was scalarized,
4546
// false if nothing changed.
47+
bool visitAllocaInst(AllocaInst &AI);
4648
bool visitInstruction(Instruction &I) { return false; }
4749
bool visitSelectInst(SelectInst &SI) { return false; }
4850
bool visitICmpInst(ICmpInst &ICI) { return false; }
@@ -67,9 +69,14 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
6769
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
6870
};
6971

70-
bool DataScalarizerVisitor::visit(Instruction &I) {
71-
assert(!GlobalMap.empty());
72-
return InstVisitor::visit(I);
72+
// Runs the visitor over every instruction of F, walking the basic blocks in
// reverse post-order. Returns true if any per-instruction handler reported a
// change, false otherwise.
bool DataScalarizerVisitor::visit(Function &F) {
  bool Changed = false;
  ReversePostOrderTraversal<Function *> BlockOrder(&F);
  for (BasicBlock *Block : make_early_inc_range(BlockOrder)) {
    // Early-increment iteration, so the loop survives a handler erasing the
    // instruction it was given.
    for (Instruction &Inst : make_early_inc_range(*Block)) {
      if (InstVisitor::visit(Inst))
        Changed = true;
    }
  }
  return Changed;
}
7481

7582
GlobalVariable *
@@ -83,6 +90,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
8390
return nullptr; // Not found
8491
}
8592

93+
// Recursively creates an array version of the given vector type.
94+
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
95+
if (auto *VecTy = dyn_cast<VectorType>(T))
96+
return ArrayType::get(VecTy->getElementType(),
97+
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
98+
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
99+
Type *NewElementType =
100+
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
101+
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
102+
}
103+
// If it's not a vector or array, return the original type.
104+
return T;
105+
}
106+
107+
// Returns true only when T is an array type whose immediate element type is a
// vector type. Deeper nesting is not inspected.
static bool isArrayOfVectors(Type *T) {
  auto *AsArray = dyn_cast<ArrayType>(T);
  return AsArray && isa<VectorType>(AsArray->getElementType());
}
112+
113+
// Replaces an alloca whose allocated type is an array of vectors with an
// alloca of the equivalent nested-array type, placed immediately before the
// original. All uses are redirected to the new alloca and the original is
// erased. Returns true if the alloca was replaced, false if its allocated
// type is not an array of vectors.
bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
  if (!isArrayOfVectors(AI.getAllocatedType()))
    return false;

  ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
  IRBuilder<> Builder(&AI);
  LLVMContext &Ctx = AI.getContext();
  Type *NewType = replaceVectorWithArray(ArrType, Ctx);
  // Carry over the address space and the element-count operand of the
  // original alloca. Without them the replacement would silently become a
  // single-element alloca in the default alloca address space, and its pointer
  // type could stop matching the uses being rewritten below.
  AllocaInst *ArrAlloca =
      Builder.CreateAlloca(NewType, AI.getAddressSpace(), AI.getArraySize(),
                           AI.getName() + ".scalarize");
  ArrAlloca->setAlignment(AI.getAlign());
  AI.replaceAllUsesWith(ArrAlloca);
  AI.eraseFromParent();
  return true;
}
128+
86129
bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
87130
unsigned NumOperands = LI.getNumOperands();
88131
for (unsigned I = 0; I < NumOperands; ++I) {
@@ -154,20 +197,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
154197
return true;
155198
}
156199

157-
// Recursively Creates and Array like version of the given vector like type.
158-
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
159-
if (auto *VecTy = dyn_cast<VectorType>(T))
160-
return ArrayType::get(VecTy->getElementType(),
161-
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
162-
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
163-
Type *NewElementType =
164-
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
165-
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
166-
}
167-
// If it's not a vector or array, return the original type.
168-
return T;
169-
}
170-
171200
Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
172201
LLVMContext &Ctx) {
173202
// Handle ConstantAggregateZero (zero-initialized constants)
@@ -253,20 +282,15 @@ static bool findAndReplaceVectors(Module &M) {
253282
// Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
254283
// type equality. Instead we will use the visitor pattern.
255284
Impl.GlobalMap[&G] = NewGlobal;
256-
for (User *U : make_early_inc_range(G.users())) {
257-
if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
258-
ConstantExpr *CE = cast<ConstantExpr>(U);
259-
for (User *UCE : make_early_inc_range(CE->users())) {
260-
if (Instruction *Inst = dyn_cast<Instruction>(UCE))
261-
Impl.visit(*Inst);
262-
}
263-
}
264-
if (Instruction *Inst = dyn_cast<Instruction>(U))
265-
Impl.visit(*Inst);
266-
}
267285
}
268286
}
269287

288+
for (auto &F : make_early_inc_range(M.functions())) {
289+
if (F.isDeclaration())
290+
continue;
291+
MadeChange |= Impl.visit(F);
292+
}
293+
270294
// Remove the old globals after the iteration
271295
for (auto &[Old, New] : Impl.GlobalMap) {
272296
Old->eraseFromParent();

llvm/test/CodeGen/DirectX/scalar-bug-117273.ll

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,18 @@
88
define internal void @main() #1 {
99
; CHECK-LABEL: define internal void @main() {
1010
; CHECK-NEXT: [[ENTRY:.*:]]
11-
; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), align 16
12-
; CHECK-NEXT: [[DOTI1:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 1), align 4
13-
; CHECK-NEXT: [[DOTI2:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 2), align 8
14-
; CHECK-NEXT: [[DOTI01:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), align 16
15-
; CHECK-NEXT: [[DOTI12:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 1), align 4
16-
; CHECK-NEXT: [[DOTI23:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 2), align 8
11+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 1
12+
; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
13+
; CHECK-NEXT: [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1
14+
; CHECK-NEXT: [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4
15+
; CHECK-NEXT: [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2
16+
; CHECK-NEXT: [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8
17+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 2
18+
; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
19+
; CHECK-NEXT: [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
20+
; CHECK-NEXT: [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4
21+
; CHECK-NEXT: [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2
22+
; CHECK-NEXT: [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8
1723
; CHECK-NEXT: ret void
1824
;
1925
entry:
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=SCHECK
2+
; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=FCHECK
3+
4+
; The RUN lines only enable the prefixes used below, so the label is written
; once per prefix. A plain CHECK-LABEL would never be evaluated.
; SCHECK-LABEL: alloca_2d__vec_test
; FCHECK-LABEL: alloca_2d__vec_test
define void @alloca_2d__vec_test() local_unnamed_addr #2 {
; SCHECK: alloca [2 x [4 x i32]], align 16
; FCHECK: alloca [8 x i32], align 16
  %1 = alloca [2 x <4 x i32>], align 16
  ret void
}

0 commit comments

Comments
 (0)