Skip to content

Commit 72baf72

Browse files
committed
address pr comments
1 parent 6f6e42c commit 72baf72

File tree

5 files changed

+95
-61
lines changed

5 files changed

+95
-61
lines changed

llvm/lib/Target/DirectX/DXILDataScalarization.cpp

Lines changed: 61 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization----===//
1+
//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
7-
//===----------------------------------------------------------------===//
7+
//===---------------------------------------------------------------------===//
88

99
#include "DXILDataScalarization.h"
1010
#include "DirectX.h"
1111
#include "llvm/ADT/PostOrderIterator.h"
1212
#include "llvm/ADT/STLExtras.h"
13+
#include "llvm/Analysis/DXILResource.h"
1314
#include "llvm/IR/GlobalVariable.h"
1415
#include "llvm/IR/IRBuilder.h"
1516
#include "llvm/IR/InstVisitor.h"
@@ -22,11 +23,21 @@
2223
#include "llvm/Transforms/Utils/Local.h"
2324

2425
#define DEBUG_TYPE "dxil-data-scalarization"
25-
#define Max_VEC_SIZE 4
26+
static const int MaxVecSize = 4;
2627

2728
using namespace llvm;
2829

29-
static void findAndReplaceVectors(Module &M);
30+
class DXILDataScalarizationLegacy : public ModulePass {
31+
32+
public:
33+
bool runOnModule(Module &M) override;
34+
DXILDataScalarizationLegacy() : ModulePass(ID) {}
35+
36+
void getAnalysisUsage(AnalysisUsage &AU) const override;
37+
static char ID; // Pass identification.
38+
};
39+
40+
static bool findAndReplaceVectors(Module &M);
3041

3142
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
3243
public:
@@ -51,10 +62,10 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
5162
bool visitStoreInst(StoreInst &SI);
5263
bool visitCallInst(CallInst &ICI) { return false; }
5364
bool visitFreezeInst(FreezeInst &FI) { return false; }
54-
friend void findAndReplaceVectors(llvm::Module &M);
65+
friend bool findAndReplaceVectors(llvm::Module &M);
5566

5667
private:
57-
GlobalVariable *getNewGlobalIfExists(Value *CurrOperand);
68+
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
5869
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
5970
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
6071
bool finish();
@@ -81,7 +92,7 @@ bool DataScalarizerVisitor::finish() {
8192
}
8293

8394
GlobalVariable *
84-
DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
95+
DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
8596
if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
8697
auto It = GlobalMap.find(OldGlobal);
8798
if (It != GlobalMap.end()) {
@@ -92,43 +103,44 @@ DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
92103
}
93104

94105
bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
95-
for (unsigned I = 0; I < LI.getNumOperands(); ++I) {
106+
unsigned NumOperands = LI.getNumOperands();
107+
for (unsigned I = 0; I < NumOperands; ++I) {
96108
Value *CurrOpperand = LI.getOperand(I);
97-
GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
98-
if (NewGlobal)
109+
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
99110
LI.setOperand(I, NewGlobal);
100111
}
101112
return false;
102113
}
103114

104115
bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
105-
for (unsigned I = 0; I < SI.getNumOperands(); ++I) {
116+
unsigned NumOperands = SI.getNumOperands();
117+
for (unsigned I = 0; I < NumOperands; ++I) {
106118
Value *CurrOpperand = SI.getOperand(I);
107-
GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
108-
if (NewGlobal) {
119+
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) {
109120
SI.setOperand(I, NewGlobal);
110121
}
111122
}
112123
return false;
113124
}
114125

115126
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
116-
for (unsigned I = 0; I < GEPI.getNumOperands(); ++I) {
127+
unsigned NumOperands = GEPI.getNumOperands();
128+
for (unsigned I = 0; I < NumOperands; ++I) {
117129
Value *CurrOpperand = GEPI.getOperand(I);
118-
GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
119-
if (NewGlobal) {
120-
IRBuilder<> Builder(&GEPI);
130+
GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand);
131+
if (!NewGlobal)
132+
continue;
133+
IRBuilder<> Builder(&GEPI);
121134

122-
SmallVector<Value *, Max_VEC_SIZE> Indices;
123-
for (auto &Index : GEPI.indices())
124-
Indices.push_back(Index);
135+
SmallVector<Value *, MaxVecSize> Indices;
136+
for (auto &Index : GEPI.indices())
137+
Indices.push_back(Index);
125138

126-
Value *NewGEP =
127-
Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
139+
Value *NewGEP =
140+
Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
128141

129-
GEPI.replaceAllUsesWith(NewGEP);
130-
PotentiallyDeadInstrs.emplace_back(&GEPI);
131-
}
142+
GEPI.replaceAllUsesWith(NewGEP);
143+
PotentiallyDeadInstrs.emplace_back(&GEPI);
132144
}
133145
return true;
134146
}
@@ -137,7 +149,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
137149
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
138150
if (auto *VecTy = dyn_cast<VectorType>(T))
139151
return ArrayType::get(VecTy->getElementType(),
140-
cast<FixedVectorType>(VecTy)->getNumElements());
152+
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
141153
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
142154
Type *NewElementType =
143155
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
@@ -162,7 +174,7 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
162174
// Handle vector to array transformation
163175
if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
164176
// Convert vector initializer to array initializer
165-
SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
177+
SmallVector<Constant *, MaxVecSize> ArrayElements;
166178
if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
167179
for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
168180
ArrayElements.push_back(ConstVecInit->getOperand(I));
@@ -171,22 +183,19 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
171183
for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
172184
ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
173185
} else {
174-
llvm_unreachable("Expected a ConstantVector or ConstantDataVector for "
175-
"vector initializer!");
186+
assert(false && "Expected a ConstantVector or ConstantDataVector for "
187+
"vector initializer!");
176188
}
177189

178190
return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
179191
}
180192

181193
// Handle array of vectors transformation
182194
if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
183-
184195
auto *ArrayInit = dyn_cast<ConstantArray>(Init);
185-
if (!ArrayInit) {
186-
llvm_unreachable("Expected a ConstantArray for array initializer!");
187-
}
196+
assert(ArrayInit && "Expected a ConstantArray for array initializer!");
188197

189-
SmallVector<Constant *, Max_VEC_SIZE> NewArrayElements;
198+
SmallVector<Constant *, MaxVecSize> NewArrayElements;
190199
for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
191200
// Recursively transform array elements
192201
Constant *NewElemInit = transformInitializer(
@@ -202,7 +211,8 @@ Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
202211
return Init;
203212
}
204213

205-
static void findAndReplaceVectors(Module &M) {
214+
static bool findAndReplaceVectors(Module &M) {
215+
bool MadeChange = false;
206216
LLVMContext &Ctx = M.getContext();
207217
IRBuilder<> Builder(Ctx);
208218
DataScalarizerVisitor Impl;
@@ -212,17 +222,17 @@ static void findAndReplaceVectors(Module &M) {
212222
Type *NewType = replaceVectorWithArray(OrigType, Ctx);
213223
if (OrigType != NewType) {
214224
// Create a new global variable with the updated type
225+
// Note: Initializer is set via transformInitializer
215226
GlobalVariable *NewGlobal = new GlobalVariable(
216227
M, NewType, G.isConstant(), G.getLinkage(),
217-
// Initializer is set via transformInitializer
218228
/*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
219229
G.getThreadLocalMode(), G.getAddressSpace(),
220230
G.isExternallyInitialized());
221231

222232
// Copy relevant attributes
223233
NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
224234
if (G.getAlignment() > 0) {
225-
NewGlobal->setAlignment(Align(G.getAlignment()));
235+
NewGlobal->setAlignment(G.getAlign());
226236
}
227237

228238
if (G.hasInitializer()) {
@@ -253,24 +263,30 @@ static void findAndReplaceVectors(Module &M) {
253263
}
254264

255265
// Remove the old globals after the iteration
256-
for (auto Pair : Impl.GlobalMap) {
257-
GlobalVariable *OldG = Pair.getFirst();
258-
OldG->eraseFromParent();
266+
for (auto &[Old, New] : Impl.GlobalMap) {
267+
Old->eraseFromParent();
268+
MadeChange = true;
259269
}
270+
return MadeChange;
260271
}
261272

262273
PreservedAnalyses DXILDataScalarization::run(Module &M,
263274
ModuleAnalysisManager &) {
264-
findAndReplaceVectors(M);
265-
return PreservedAnalyses::none();
275+
bool MadeChanges = findAndReplaceVectors(M);
276+
if (!MadeChanges)
277+
return PreservedAnalyses::all();
278+
PreservedAnalyses PA;
279+
PA.preserve<DXILResourceAnalysis>();
280+
return PA;
266281
}
267282

268283
bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
269-
findAndReplaceVectors(M);
270-
return true;
284+
return findAndReplaceVectors(M);
271285
}
272286

273-
void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {}
287+
void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
288+
AU.addPreserved<DXILResourceWrapperPass>();
289+
}
274290

275291
char DXILDataScalarizationLegacy::ID = 0;
276292

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
//===- DXILDataScalarization.h - Prepare LLVM Module for DXIL Data
2-
// Legalization----===//
1+
//===- DXILDataScalarization.h - Perform DXIL Data Legalization -*- C++ -*-===//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54
// See https://llvm.org/LICENSE.txt for license information.
65
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
76
//
8-
//===------------------------------------------------------------------------------===//
7+
//===---------------------------------------------------------------------===//
8+
99
#ifndef LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
1010
#define LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
1111

@@ -20,16 +20,6 @@ class DXILDataScalarization : public PassInfoMixin<DXILDataScalarization> {
2020
public:
2121
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
2222
};
23-
24-
class DXILDataScalarizationLegacy : public ModulePass {
25-
26-
public:
27-
bool runOnModule(Module &M) override;
28-
DXILDataScalarizationLegacy() : ModulePass(ID) {}
29-
30-
void getAnalysisUsage(AnalysisUsage &AU) const override;
31-
static char ID; // Pass identification.
32-
};
3323
} // namespace llvm
3424

3525
#endif // LLVM_TARGET_DIRECTX_DXILDATASCALARIZATION_H
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
2+
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
3+
4+
; Make sure we don't touch arrays without vectors and that we can recurse into multi-dimensional arrays of vectors
5+
6+
@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4
7+
@"groushared3dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x <4 x i32>]]] zeroinitializer, align 16
8+
9+
; CHECK @staticArray
10+
; CHECK-NOT: @staticArray.scalarized
11+
; CHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16
12+
; CHECK-NOT: @groushared3dArrayofVectors

llvm/test/CodeGen/DirectX/scalar-load.ll

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
22
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
3+
4+
; Make sure we can load groupshared and static vectors and arrays of vectors
5+
36
@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
47
@"vecData" = external addrspace(3) global <4 x i32>, align 4
58
@staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
6-
@staticArray = internal global [4 x i32] [i32 1, i32 2, i32 3, i32 4], align 4
9+
@"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16
710

811
; CHECK: @arrayofVecData.scalarized = local_unnamed_addr addrspace(3) global [2 x [3 x float]] zeroinitializer, align 16
912
; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
1013
; CHECK: @staticArrayOfVecData.scalarized = internal global [3 x [4 x i32]] {{\[}}[4 x i32] [i32 1, i32 2, i32 3, i32 4], [4 x i32] [i32 5, i32 6, i32 7, i32 8], [4 x i32] [i32 9, i32 10, i32 11, i32 12]], align 4
11-
; Check @staticArray
14+
; CHECK: @groushared2dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [4 x i32]]] zeroinitializer, align 16
1215

1316
; CHECK-NOT: @arrayofVecData
1417
; CHECK-NOT: @vecData
1518
; CHECK-NOT: @staticArrayOfVecData
16-
; CHECK-NOT: @staticArray.scalarized
19+
; CHECK-NOT: @groushared2dArrayofVectors
20+
1721

1822
; CHECK-LABEL: load_array_vec_test
1923
define <4 x i32> @load_array_vec_test() {
@@ -42,3 +46,13 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) {
4246
%4 = load <4 x i32>, <4 x i32>* %3, align 4
4347
ret <4 x i32> %4
4448
}
49+
50+
; CHECK-LABEL: multid_load_test
51+
define <4 x i32> @multid_load_test() {
52+
; CHECK-COUNT-8: load i32, ptr addrspace(3) {{(.*@groushared2dArrayofVectors.scalarized.*|%.*)}}, align 4
53+
; CHECK-NOT: load i32, ptr addrspace(3) {{.*}}, align 4
54+
%1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4
55+
%2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 1, i32 1), align 4
56+
%3 = add <4 x i32> %1, %2
57+
ret <4 x i32> %3
58+
}

llvm/test/CodeGen/DirectX/scalar-store.ll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
; RUN: opt -S -dxil-data-scalarization -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
22
; RUN: llc %s -mtriple=dxil-pc-shadermodel6.3-library --filetype=asm -o - | FileCheck %s
33

4+
; Make sure we can store groupshared and static vectors and arrays of vectors
5+
46
@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
57
@"vecData" = external addrspace(3) global <4 x i32>, align 4
68

0 commit comments

Comments
 (0)