Skip to content

Commit fc6aac7

Browse files
authored
[DirectX] Fix bug where Flatten arrays was only using last index (#144146)
fixes #142836 We added a function called `collectIndicesAndDimsFromGEP` which builds the Indicies and Dims up for the recursive case and the base case. really to solve #142836 we didn't need to add it to the recursive case. The recursive cases exists for gep chains which are ussually two indicies per gep ie ptr index and array index. adding collectIndicesAndDimsFromGEP to the recursive cases means we can now do some mixed mode indexing say we get a case where its not the ussual 2 indicies but instead 3 we can now treat those last two indicies as part of the computation for the flat array index.
1 parent 58d2347 commit fc6aac7

File tree

3 files changed

+109
-13
lines changed

3 files changed

+109
-13
lines changed

llvm/lib/Target/DirectX/DXILFlattenArrays.cpp

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ class DXILFlattenArraysVisitor
8686
Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
8787
ArrayRef<uint64_t> Dims,
8888
IRBuilder<> &Builder);
89+
90+
// Helper function to collect indices and dimensions from a GEP instruction
91+
void collectIndicesAndDimsFromGEP(GetElementPtrInst &GEP,
92+
SmallVectorImpl<Value *> &Indices,
93+
SmallVectorImpl<uint64_t> &Dims,
94+
bool &AllIndicesAreConstInt);
95+
8996
void
9097
recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
9198
ArrayType *FlattenedArrayType, Value *PtrOperand,
@@ -218,6 +225,26 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
218225
return true;
219226
}
220227

228+
void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP(
229+
GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
230+
SmallVectorImpl<uint64_t> &Dims, bool &AllIndicesAreConstInt) {
231+
232+
Type *CurrentType = GEP.getSourceElementType();
233+
234+
// Note index 0 is the ptr index.
235+
for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) {
236+
Indices.push_back(Index);
237+
AllIndicesAreConstInt &= isa<ConstantInt>(Index);
238+
239+
if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
240+
Dims.push_back(ArrayTy->getNumElements());
241+
CurrentType = ArrayTy->getElementType();
242+
} else {
243+
assert(false && "Expected array type in GEP chain");
244+
}
245+
}
246+
}
247+
221248
void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
222249
GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
223250
Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
@@ -226,12 +253,8 @@ void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
226253
if (GEPChainMap.count(&CurrGEP) > 0)
227254
return;
228255

229-
Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
230-
AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
231-
Indices.push_back(LastIndex);
232-
assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
233-
Dims.push_back(
234-
cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
256+
// Collect indices and dimensions from the current GEP
257+
collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt);
235258
bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
236259
if (!IsMultiDimArr) {
237260
assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
@@ -316,9 +339,12 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
316339
// Handle zero uses here because there won't be an update via
317340
// a child in the chain later.
318341
if (GEPChainUseCount == 0) {
319-
SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
320-
SmallVector<uint64_t> Dims({ArrType->getNumElements()});
321-
bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]);
342+
SmallVector<Value *> Indices;
343+
SmallVector<uint64_t> Dims;
344+
bool AllIndicesAreConstInt = true;
345+
346+
// Collect indices and dimensions from the GEP
347+
collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt);
322348
GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
323349
std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
324350
return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);

llvm/test/CodeGen/DirectX/flatten-array.ll

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,5 +187,75 @@ define void @global_gep_store() {
187187
ret void
188188
}
189189

190+
@g = local_unnamed_addr addrspace(3) global [2 x [2 x float]] zeroinitializer, align 4
191+
define void @two_index_gep() {
192+
; CHECK-LABEL: define void @two_index_gep(
193+
; CHECK: [[THREAD_ID:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
194+
; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[THREAD_ID]], 2
195+
; CHECK-NEXT: [[ADD:%.*]] = add i32 1, [[MUL]]
196+
; CHECK-NEXT: [[GEP_PTR:%.*]] = getelementptr inbounds nuw [4 x float], ptr addrspace(3) @g.1dim, i32 0, i32 [[ADD]]
197+
; CHECK-NEXT: load float, ptr addrspace(3) [[GEP_PTR]], align 4
198+
; CHECK-NEXT: ret void
199+
%1 = tail call i32 @llvm.dx.thread.id(i32 0)
200+
%2 = getelementptr inbounds nuw [2 x [2 x float]], ptr addrspace(3) @g, i32 0, i32 %1, i32 1
201+
%3 = load float, ptr addrspace(3) %2, align 4
202+
ret void
203+
}
204+
205+
define void @two_index_gep_const() {
206+
; CHECK-LABEL: define void @two_index_gep_const(
207+
; CHECK-NEXT: [[GEP_PTR:%.*]] = getelementptr inbounds nuw [4 x float], ptr addrspace(3) @g.1dim, i32 0, i32 3
208+
; CHECK-NEXT: load float, ptr addrspace(3) [[GEP_PTR]], align 4
209+
; CHECK-NEXT: ret void
210+
%1 = getelementptr inbounds nuw [2 x [2 x float]], ptr addrspace(3) @g, i32 0, i32 1, i32 1
211+
%3 = load float, ptr addrspace(3) %1, align 4
212+
ret void
213+
}
214+
215+
define void @gep_4d_index_test() {
216+
; CHECK-LABEL: gep_4d_index_test
217+
; CHECK: [[a:%.*]] = alloca [16 x i32], align 4
218+
; CHECK-NEXT: getelementptr inbounds [16 x i32], ptr %.1dim, i32 0, i32 1
219+
; CHECK-NEXT: getelementptr inbounds [16 x i32], ptr %.1dim, i32 0, i32 3
220+
; CHECK-NEXT: getelementptr inbounds [16 x i32], ptr %.1dim, i32 0, i32 7
221+
; CHECK-NEXT: getelementptr inbounds [16 x i32], ptr %.1dim, i32 0, i32 15
222+
; CHECK-NEXT: ret void
223+
%1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
224+
%2 = getelementptr inbounds [2 x [2 x[2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 0, i32 0, i32 1
225+
%3 = getelementptr inbounds [2 x [2 x[2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 0, i32 1, i32 1
226+
%4 = getelementptr inbounds [2 x [2 x[2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 1, i32 1, i32 1
227+
%5 = getelementptr inbounds [2 x [2 x[2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 1, i32 1, i32 1, i32 1
228+
ret void
229+
}
230+
231+
define void @gep_4d_index_and_gep_chain_mixed() {
232+
; CHECK-LABEL: gep_4d_index_and_gep_chain_mixed
233+
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [16 x i32], align 4
234+
; CHECK-COUNT-16: getelementptr inbounds [16 x i32], ptr [[ALLOCA]], i32 0, i32 {{[0-9]|1[0-5]}}
235+
; CHECK-NEXT: ret void
236+
%1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
237+
%a4d0_0 = getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]], [2 x [2 x[2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 0
238+
%a2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %a4d0_0, i32 0, i32 0, i32 0
239+
%a2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %a4d0_0, i32 0, i32 0, i32 1
240+
%a2d1_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %a4d0_0, i32 0, i32 1, i32 0
241+
%a2d1_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %a4d0_0, i32 0, i32 1, i32 1
242+
%b4d0_1 = getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 1
243+
%b2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %b4d0_1, i32 0, i32 0, i32 0
244+
%b2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %b4d0_1, i32 0, i32 0, i32 1
245+
%b2d1_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %b4d0_1, i32 0, i32 1, i32 0
246+
%b2d1_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %b4d0_1, i32 0, i32 1, i32 1
247+
%c4d1_0 = getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 1, i32 0
248+
%c2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %c4d1_0, i32 0, i32 0, i32 0
249+
%c2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %c4d1_0, i32 0, i32 0, i32 1
250+
%c2d1_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %c4d1_0, i32 0, i32 1, i32 0
251+
%c2d1_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %c4d1_0, i32 0, i32 1, i32 1
252+
%g4d1_1 = getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 1, i32 1
253+
%g2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g4d1_1, i32 0, i32 0, i32 0
254+
%g2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g4d1_1, i32 0, i32 0, i32 1
255+
%g2d1_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g4d1_1, i32 0, i32 1, i32 0
256+
%g2d1_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g4d1_1, i32 0, i32 1, i32 1
257+
ret void
258+
}
259+
190260
; Make sure we don't try to walk the body of a function declaration.
191261
declare void @opaque_function()

llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ define <4 x i32> @multid_load_test() #0 {
111111
; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
112112
; CHECK-NEXT: [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 3) to ptr addrspace(3)
113113
; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
114-
; CHECK-NEXT: [[TMP9:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 1) to ptr addrspace(3)
114+
; CHECK-NEXT: [[TMP9:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4) to ptr addrspace(3)
115115
; CHECK-NEXT: [[TMP10:%.*]] = load i32, ptr addrspace(3) [[TMP9]], align 4
116-
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 1), i32 1) to ptr addrspace(3)
116+
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 1) to ptr addrspace(3)
117117
; CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4
118-
; CHECK-NEXT: [[TMP13:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 1), i32 2) to ptr addrspace(3)
118+
; CHECK-NEXT: [[TMP13:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 2) to ptr addrspace(3)
119119
; CHECK-NEXT: [[TMP14:%.*]] = load i32, ptr addrspace(3) [[TMP13]], align 4
120-
; CHECK-NEXT: [[TMP15:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 1), i32 3) to ptr addrspace(3)
120+
; CHECK-NEXT: [[TMP15:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 3) to ptr addrspace(3)
121121
; CHECK-NEXT: [[TMP16:%.*]] = load i32, ptr addrspace(3) [[TMP15]], align 4
122122
; CHECK-NEXT: [[DOTI05:%.*]] = add i32 [[TMP2]], [[TMP10]]
123123
; CHECK-NEXT: [[DOTI16:%.*]] = add i32 [[TMP4]], [[TMP12]]

0 commit comments

Comments
 (0)