Skip to content

Commit 2a4207e

Browse files
authored
[DirectX] Don't limit visitGetElementPtrInst to global ptrs (#144959)
Fixes #144608.
- There is a getPointerOperandIndex function, so we don't need to iterate over the operands to find the pointer. This allowed a small cleanup of visitStoreInst and visitLoadInst.
- The meat of this change is in visitGetElementPtrInst, which now accounts for allocas and no longer bails out when it doesn't find a replacement global.
1 parent 958dc86 commit 2a4207e

File tree

2 files changed

+72
-51
lines changed

2 files changed

+72
-51
lines changed

llvm/lib/Target/DirectX/DXILDataScalarization.cpp

Lines changed: 55 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
#include "llvm/IR/GlobalVariable.h"
1515
#include "llvm/IR/IRBuilder.h"
1616
#include "llvm/IR/InstVisitor.h"
17+
#include "llvm/IR/Instructions.h"
1718
#include "llvm/IR/Module.h"
1819
#include "llvm/IR/Operator.h"
1920
#include "llvm/IR/PassManager.h"
2021
#include "llvm/IR/ReplaceConstant.h"
2122
#include "llvm/IR/Type.h"
23+
#include "llvm/Support/Casting.h"
2224
#include "llvm/Transforms/Utils/Cloning.h"
2325
#include "llvm/Transforms/Utils/Local.h"
2426

@@ -137,49 +139,42 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
137139
}
138140

139141
bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
140-
unsigned NumOperands = LI.getNumOperands();
141-
for (unsigned I = 0; I < NumOperands; ++I) {
142-
Value *CurrOpperand = LI.getOperand(I);
143-
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
144-
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
145-
GetElementPtrInst *OldGEP =
146-
cast<GetElementPtrInst>(CE->getAsInstruction());
147-
OldGEP->insertBefore(LI.getIterator());
148-
IRBuilder<> Builder(&LI);
149-
LoadInst *NewLoad =
150-
Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
151-
NewLoad->setAlignment(LI.getAlign());
152-
LI.replaceAllUsesWith(NewLoad);
153-
LI.eraseFromParent();
154-
visitGetElementPtrInst(*OldGEP);
155-
return true;
156-
}
157-
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
158-
LI.setOperand(I, NewGlobal);
142+
Value *PtrOperand = LI.getPointerOperand();
143+
ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
144+
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
145+
GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
146+
OldGEP->insertBefore(LI.getIterator());
147+
IRBuilder<> Builder(&LI);
148+
LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
149+
NewLoad->setAlignment(LI.getAlign());
150+
LI.replaceAllUsesWith(NewLoad);
151+
LI.eraseFromParent();
152+
visitGetElementPtrInst(*OldGEP);
153+
return true;
159154
}
155+
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
156+
LI.setOperand(LI.getPointerOperandIndex(), NewGlobal);
160157
return false;
161158
}
162159

163160
bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
164-
unsigned NumOperands = SI.getNumOperands();
165-
for (unsigned I = 0; I < NumOperands; ++I) {
166-
Value *CurrOpperand = SI.getOperand(I);
167-
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
168-
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
169-
GetElementPtrInst *OldGEP =
170-
cast<GetElementPtrInst>(CE->getAsInstruction());
171-
OldGEP->insertBefore(SI.getIterator());
172-
IRBuilder<> Builder(&SI);
173-
StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
174-
NewStore->setAlignment(SI.getAlign());
175-
SI.replaceAllUsesWith(NewStore);
176-
SI.eraseFromParent();
177-
visitGetElementPtrInst(*OldGEP);
178-
return true;
179-
}
180-
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
181-
SI.setOperand(I, NewGlobal);
161+
162+
Value *PtrOperand = SI.getPointerOperand();
163+
ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
164+
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
165+
GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
166+
OldGEP->insertBefore(SI.getIterator());
167+
IRBuilder<> Builder(&SI);
168+
StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
169+
NewStore->setAlignment(SI.getAlign());
170+
SI.replaceAllUsesWith(NewStore);
171+
SI.eraseFromParent();
172+
visitGetElementPtrInst(*OldGEP);
173+
return true;
182174
}
175+
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
176+
SI.setOperand(SI.getPointerOperandIndex(), NewGlobal);
177+
183178
return false;
184179
}
185180

@@ -302,24 +297,35 @@ bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
302297
}
303298

304299
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
305-
306-
unsigned NumOperands = GEPI.getNumOperands();
307-
GlobalVariable *NewGlobal = nullptr;
308-
for (unsigned I = 0; I < NumOperands; ++I) {
309-
Value *CurrOpperand = GEPI.getOperand(I);
310-
NewGlobal = lookupReplacementGlobal(CurrOpperand);
311-
if (NewGlobal)
312-
break;
300+
Value *PtrOperand = GEPI.getPointerOperand();
301+
Type *OrigGEPType = GEPI.getPointerOperandType();
302+
Type *NewGEPType = OrigGEPType;
303+
bool NeedsTransform = false;
304+
305+
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
306+
NewGEPType = NewGlobal->getValueType();
307+
PtrOperand = NewGlobal;
308+
NeedsTransform = true;
309+
} else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
310+
Type *AllocatedType = Alloca->getAllocatedType();
311+
// OrigGEPType might just be a pointer lets make sure
312+
// to add the allocated type so we have a size
313+
if (AllocatedType != OrigGEPType) {
314+
NewGEPType = AllocatedType;
315+
NeedsTransform = true;
316+
}
313317
}
314-
if (!NewGlobal)
318+
319+
// Note: We bail if this isn't a gep touched via alloca or global
320+
// transformations
321+
if (!NeedsTransform)
315322
return false;
316323

317324
IRBuilder<> Builder(&GEPI);
318325
SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());
319326

320-
Value *NewGEP =
321-
Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
322-
GEPI.getName(), GEPI.getNoWrapFlags());
327+
Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
328+
GEPI.getName(), GEPI.getNoWrapFlags());
323329
GEPI.replaceAllUsesWith(NewGEP);
324330
GEPI.eraseFromParent();
325331
return true;
Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
1-
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=SCHECK
2-
; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=FCHECK
1+
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
2+
; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
33

44
; CHECK-LABEL: alloca_2d__vec_test
55
define void @alloca_2d__vec_test() local_unnamed_addr #2 {
66
; SCHECK: alloca [2 x [4 x i32]], align 16
77
; FCHECK: alloca [8 x i32], align 16
8+
; CHECK: ret void
89
%1 = alloca [2 x <4 x i32>], align 16
910
ret void
1011
}
12+
13+
; CHECK-LABEL: alloca_2d_gep_test
14+
define void @alloca_2d_gep_test() {
15+
; SCHECK: [[alloca_val:%.*]] = alloca [2 x [2 x i32]], align 16
16+
; FCHECK: [[alloca_val:%.*]] = alloca [4 x i32], align 16
17+
; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
18+
; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [2 x [2 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
19+
; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 0, i32 [[tid]]
20+
; CHECK: ret void
21+
%1 = alloca [2 x <2 x i32>], align 16
22+
%2 = tail call i32 @llvm.dx.thread.id(i32 0)
23+
%3 = getelementptr inbounds nuw [2 x <2 x i32>], ptr %1, i32 0, i32 %2
24+
ret void
25+
}

0 commit comments

Comments
 (0)