Skip to content

Commit 4cbaf28

Browse files
committed
[DirectX] Legalize i8 allocas
1 parent 82a1d50 commit 4cbaf28

File tree

2 files changed

+140
-9
lines changed

2 files changed

+140
-9
lines changed

llvm/lib/Target/DirectX/DXILLegalizePass.cpp

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "llvm/IR/IRBuilder.h"
1313
#include "llvm/IR/InstIterator.h"
1414
#include "llvm/IR/Instruction.h"
15+
#include "llvm/IR/Instructions.h"
1516
#include "llvm/Pass.h"
1617
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
1718
#include <functional>
@@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I,
3132
ToRemove.push_back(FI);
3233
}
3334

34-
static void fixI8TruncUseChain(Instruction &I,
35-
SmallVectorImpl<Instruction *> &ToRemove,
36-
DenseMap<Value *, Value *> &ReplacedValues) {
35+
static void fixI8UseChain(Instruction &I,
36+
SmallVectorImpl<Instruction *> &ToRemove,
37+
DenseMap<Value *, Value *> &ReplacedValues) {
3738

3839
auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
3940
Type *InstrType = IntegerType::get(I.getContext(), 32);
4041

4142
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
4243
Value *Op = I.getOperand(OpIdx);
43-
if (ReplacedValues.count(Op))
44+
if (ReplacedValues.count(Op) &&
45+
ReplacedValues[Op]->getType()->isIntegerTy())
4446
InstrType = ReplacedValues[Op]->getType();
4547
}
4648

@@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I,
7375
}
7476
}
7577

78+
if (auto *Store = dyn_cast<StoreInst>(&I)) {
79+
if (!Store->getValueOperand()->getType()->isIntegerTy(8))
80+
return;
81+
SmallVector<Value *> NewOperands;
82+
ProcessOperands(NewOperands);
83+
Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
84+
ReplacedValues[Store] = NewStore;
85+
ToRemove.push_back(Store);
86+
return;
87+
}
88+
89+
if (auto *Load = dyn_cast<LoadInst>(&I)) {
90+
if (!I.getType()->isIntegerTy(8))
91+
return;
92+
SmallVector<Value *> NewOperands;
93+
ProcessOperands(NewOperands);
94+
Type *ElementType = NewOperands[0]->getType();
95+
if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0]))
96+
ElementType = AI->getAllocatedType();
97+
LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
98+
ReplacedValues[Load] = NewLoad;
99+
ToRemove.push_back(Load);
100+
return;
101+
}
102+
76103
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
77104
if (!I.getType()->isIntegerTy(8))
78105
return;
@@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I,
81108
Value *NewInst =
82109
Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
83110
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
84-
if (OBO->hasNoSignedWrap())
85-
cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
86-
if (OBO->hasNoUnsignedWrap())
87-
cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
111+
auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
112+
if (NewBO && OBO->hasNoSignedWrap())
113+
NewBO->setHasNoSignedWrap();
114+
if (NewBO && OBO->hasNoUnsignedWrap())
115+
NewBO->setHasNoUnsignedWrap();
88116
}
89117
ReplacedValues[BO] = NewInst;
90118
ToRemove.push_back(BO);
91119
return;
92120
}
93121

122+
if (auto *Sel = dyn_cast<SelectInst>(&I)) {
123+
if (!I.getType()->isIntegerTy(8))
124+
return;
125+
SmallVector<Value *> NewOperands;
126+
ProcessOperands(NewOperands);
127+
Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
128+
NewOperands[2]);
129+
ReplacedValues[Sel] = NewInst;
130+
ToRemove.push_back(Sel);
131+
return;
132+
}
133+
94134
if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
95135
if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
96136
return;
@@ -105,13 +145,50 @@ static void fixI8TruncUseChain(Instruction &I,
105145
}
106146

107147
if (auto *Cast = dyn_cast<CastInst>(&I)) {
148+
108149
if (Cast->getSrcTy()->isIntegerTy(8)) {
109150
ToRemove.push_back(Cast);
110151
Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
111152
}
112153
}
113154
}
114155

156+
// Legalization step for `alloca i8`: when every cast applied to a value
// loaded from the alloca yields the same type, create a replacement alloca
// of that type.
//
// \param I              Instruction under inspection; anything that is not
//                       an alloca whose allocated type is i8 is ignored.
// \param ToRemove       Receives the original alloca when a replacement is
//                       created.
// \param ReplacedValues Receives the mapping original alloca -> replacement.
//
// The alloca's users are not rewritten here; only the mapping is recorded.
// NOTE(review): presumably fixI8UseChain consumes that mapping to redirect
// the loads and stores - confirm.
static void upcastI8AllocasAndUses(Instruction &I,
157+
SmallVectorImpl<Instruction *> &ToRemove,
158+
DenseMap<Value *, Value *> &ReplacedValues) {
159+
auto *AI = dyn_cast<AllocaInst>(&I);
160+
if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
161+
return;
162+
163+
// Empty until the first cast of a loaded value has been seen.
std::optional<Type *> TargetType;
164+
bool Conflict = false;
165+
// Walk the alloca -> load -> cast chains. Users of the alloca that are not
// loads, and users of a load that are not casts, take no part in the decision.
for (User *U : AI->users()) {
166+
auto *Load = dyn_cast<LoadInst>(U);
167+
if (!Load)
168+
continue;
169+
for (User *LU : Load->users()) {
170+
// NOTE(review): any CastInst is accepted here; the destination type is not
// checked to be a wider integer - confirm that is intended.
auto *Cast = dyn_cast<CastInst>(LU);
171+
if (!Cast)
172+
continue;
173+
Type *T = Cast->getType();
174+
// The first cast encountered fixes the candidate type.
if (!TargetType)
175+
TargetType = T;
176+
177+
if (TargetType.value() != T) {
178+
Conflict = true;
179+
// Leaves only the inner loop; the remaining alloca users are still visited,
// but Conflict stays set so the function bails out below.
break;
180+
}
181+
}
182+
}
183+
// No cast of any load was found, or the casts disagree on the destination
// type: leave the alloca untouched.
if (!TargetType || Conflict)
184+
return;
185+
186+
// The builder is anchored at the original alloca.
IRBuilder<> Builder(AI);
187+
// NOTE(review): the replacement is created from the type alone; the original
// alloca's alignment is not forwarded - confirm the default is acceptable.
AllocaInst *NewAlloca = Builder.CreateAlloca(TargetType.value());
188+
ReplacedValues[AI] = NewAlloca;
189+
ToRemove.push_back(AI);
190+
}
191+
115192
static void
116193
downcastI64toI32InsertExtractElements(Instruction &I,
117194
SmallVectorImpl<Instruction *> &ToRemove,
@@ -178,7 +255,8 @@ class DXILLegalizationPipeline {
178255
LegalizationPipeline;
179256

180257
// Appends the legalization callbacks to LegalizationPipeline. In this diff
// view the entry marked `-` is the registration being removed and the entries
// marked `+` are the registrations that replace it.
// NOTE(review): upcastI8AllocasAndUses is appended ahead of fixI8UseChain,
// presumably so that replacement allocas are already recorded when the i8 use
// chain is rewritten - confirm that the pipeline runs in append order.
void initializeLegalizationPipeline() {
181-
LegalizationPipeline.push_back(fixI8TruncUseChain);
258+
LegalizationPipeline.push_back(upcastI8AllocasAndUses);
259+
LegalizationPipeline.push_back(fixI8UseChain);
182260
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
183261
LegalizationPipeline.push_back(legalizeFreeze);
184262
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
2+
3+
; Constant case: the i8 constant 1 is stored to an i8 alloca, loaded back,
; zero-extended to i32 and written into an i32 array slot. The only cast of
; the loaded value targets i32.
; NOTE(review): the output of opt is piped to FileCheck, but no FileCheck
; directives are visible anywhere in this test - confirm the expected-output
; lines were not lost.
define void @const_i8_store() {
4+
%accum.i.flat = alloca [1 x i32], align 4
5+
%i = alloca i8, align 4
6+
store i8 1, ptr %i
7+
%i8.load = load i8, ptr %i
8+
%z = zext i8 %i8.load to i32
9+
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
10+
store i32 %z, ptr %gep, align 4
11+
ret void
12+
}
13+
14+
; Binary-operator case: the value stored to the i8 alloca is the result of an
; i8 `add nsw` on two constants. It is then loaded, zero-extended to i32 and
; written into an i32 array slot.
define void @const_add_i8_store() {
15+
%accum.i.flat = alloca [1 x i32], align 4
16+
%i = alloca i8, align 4
17+
; i8 arithmetic carrying the nsw flag feeds the store below.
%add_i8 = add nsw i8 3, 1
18+
store i8 %add_i8, ptr %i
19+
%i8.load = load i8, ptr %i
20+
%z = zext i8 %i8.load to i32
21+
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
22+
store i32 %z, ptr %gep, align 4
23+
ret void
24+
}
25+
26+
; Select case: the value stored to the i8 alloca is chosen at run time by an
; i8 `select` on the i1 argument. It is then loaded, zero-extended to i32 and
; written into an i32 array slot.
define void @var_i8_store(i1 %cmp.i8) {
27+
%accum.i.flat = alloca [1 x i32], align 4
28+
%i = alloca i8, align 4
29+
; i8 select between two constants feeds the store below.
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
30+
store i8 %select.i8, ptr %i
31+
%i8.load = load i8, ptr %i
32+
%z = zext i8 %i8.load to i32
33+
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
34+
store i32 %z, ptr %gep, align 4
35+
ret void
36+
}
37+
38+
; Conflicting-cast case: the same i8 load is zero-extended to i16 (%z) and to
; i32 (%z2), so the casts of the loaded value do not agree on one destination
; type.
; NOTE(review): the intended legalized form for this input is not stated in
; the test - confirm what output is expected.
define void @conflicting_cast(i1 %cmp.i8) {
39+
%accum.i.flat = alloca [2 x i32], align 4
40+
%i = alloca i8, align 4
41+
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
42+
store i8 %select.i8, ptr %i
43+
%i8.load = load i8, ptr %i
44+
; First cast of the load: destination type i16.
%z = zext i8 %i8.load to i16
45+
%gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
46+
store i16 %z, ptr %gep1, align 2
47+
%gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
48+
store i16 %z, ptr %gep2, align 2
49+
; Second cast of the same load: destination type i32.
%z2 = zext i8 %i8.load to i32
50+
%gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
51+
store i32 %z2, ptr %gep3, align 4
52+
ret void
53+
}

0 commit comments

Comments
 (0)