Skip to content

Commit d3d35ad

Browse files
authored
[DirectX] Legalize i8 allocas (#137399)
Fixes #137202. While investigating i8 allocas, I found that our i8 legalization was missing handling for three instructions: load, store, and select. Added those three. To legalize i8 allocas correctly, though, we needed to walk the uses and find the casts. After finding the casts, I chose the smallest cast target type as the type to transform the alloca to. That lets me preserve the larger casts that come later.
1 parent 7dd8122 commit d3d35ad

File tree

2 files changed

+200
-12
lines changed

2 files changed

+200
-12
lines changed

llvm/lib/Target/DirectX/DXILLegalizePass.cpp

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "llvm/IR/IRBuilder.h"
1313
#include "llvm/IR/InstIterator.h"
1414
#include "llvm/IR/Instruction.h"
15+
#include "llvm/IR/Instructions.h"
1516
#include "llvm/Pass.h"
1617
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
1718
#include <functional>
@@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I,
3132
ToRemove.push_back(FI);
3233
}
3334

34-
static void fixI8TruncUseChain(Instruction &I,
35-
SmallVectorImpl<Instruction *> &ToRemove,
36-
DenseMap<Value *, Value *> &ReplacedValues) {
35+
static void fixI8UseChain(Instruction &I,
36+
SmallVectorImpl<Instruction *> &ToRemove,
37+
DenseMap<Value *, Value *> &ReplacedValues) {
3738

3839
auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
3940
Type *InstrType = IntegerType::get(I.getContext(), 32);
4041

4142
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
4243
Value *Op = I.getOperand(OpIdx);
43-
if (ReplacedValues.count(Op))
44+
if (ReplacedValues.count(Op) &&
45+
ReplacedValues[Op]->getType()->isIntegerTy())
4446
InstrType = ReplacedValues[Op]->getType();
4547
}
4648

@@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I,
7375
}
7476
}
7577

78+
if (auto *Store = dyn_cast<StoreInst>(&I)) {
79+
if (!Store->getValueOperand()->getType()->isIntegerTy(8))
80+
return;
81+
SmallVector<Value *> NewOperands;
82+
ProcessOperands(NewOperands);
83+
Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
84+
ReplacedValues[Store] = NewStore;
85+
ToRemove.push_back(Store);
86+
return;
87+
}
88+
89+
if (auto *Load = dyn_cast<LoadInst>(&I)) {
90+
if (!I.getType()->isIntegerTy(8))
91+
return;
92+
SmallVector<Value *> NewOperands;
93+
ProcessOperands(NewOperands);
94+
Type *ElementType = NewOperands[0]->getType();
95+
if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0]))
96+
ElementType = AI->getAllocatedType();
97+
LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
98+
ReplacedValues[Load] = NewLoad;
99+
ToRemove.push_back(Load);
100+
return;
101+
}
102+
76103
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
77104
if (!I.getType()->isIntegerTy(8))
78105
return;
@@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I,
81108
Value *NewInst =
82109
Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
83110
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
84-
if (OBO->hasNoSignedWrap())
85-
cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
86-
if (OBO->hasNoUnsignedWrap())
87-
cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
111+
auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
112+
if (NewBO && OBO->hasNoSignedWrap())
113+
NewBO->setHasNoSignedWrap();
114+
if (NewBO && OBO->hasNoUnsignedWrap())
115+
NewBO->setHasNoUnsignedWrap();
88116
}
89117
ReplacedValues[BO] = NewInst;
90118
ToRemove.push_back(BO);
91119
return;
92120
}
93121

122+
if (auto *Sel = dyn_cast<SelectInst>(&I)) {
123+
if (!I.getType()->isIntegerTy(8))
124+
return;
125+
SmallVector<Value *> NewOperands;
126+
ProcessOperands(NewOperands);
127+
Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
128+
NewOperands[2]);
129+
ReplacedValues[Sel] = NewInst;
130+
ToRemove.push_back(Sel);
131+
return;
132+
}
133+
94134
if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
95135
if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
96136
return;
@@ -105,13 +145,61 @@ static void fixI8TruncUseChain(Instruction &I,
105145
}
106146

107147
if (auto *Cast = dyn_cast<CastInst>(&I)) {
108-
if (Cast->getSrcTy()->isIntegerTy(8)) {
109-
ToRemove.push_back(Cast);
110-
Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
148+
if (!Cast->getSrcTy()->isIntegerTy(8))
149+
return;
150+
151+
ToRemove.push_back(Cast);
152+
auto *Replacement = ReplacedValues[Cast->getOperand(0)];
153+
if (Cast->getType() == Replacement->getType()) {
154+
Cast->replaceAllUsesWith(Replacement);
155+
return;
111156
}
157+
Value *AdjustedCast = nullptr;
158+
if (Cast->getOpcode() == Instruction::ZExt)
159+
AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
160+
if (Cast->getOpcode() == Instruction::SExt)
161+
AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());
162+
163+
if (AdjustedCast)
164+
Cast->replaceAllUsesWith(AdjustedCast);
112165
}
113166
}
114167

168+
/// Widen an `alloca i8` to the narrowest integer type that a value loaded
/// from it is subsequently cast to.
///
/// The function inspects every load of the alloca and every cast of those
/// loads, and picks the cast result type with the fewest bits. Only integer
/// result types strictly wider than 8 bits are considered:
///   * a narrowing cast (e.g. `trunc i8 %v to i1`) would otherwise shrink the
///     slot below 8 bits and drop bits that other users of the load need;
///   * a non-integer cast (e.g. `uitofp i8 %v to float`) would otherwise make
///     the slot a float.
///
/// \param I              Instruction under inspection; anything other than an
///                       `alloca i8` is ignored.
/// \param ToRemove       Receives the original alloca when a replacement is
///                       created. The alloca is only queued here, not erased.
/// \param ReplacedValues Receives the mapping original alloca -> new alloca.
///                       Users of the alloca are not rewritten by this
///                       function.
static void upcastI8AllocasAndUses(Instruction &I,
                                   SmallVectorImpl<Instruction *> &ToRemove,
                                   DenseMap<Value *, Value *> &ReplacedValues) {
  auto *AI = dyn_cast<AllocaInst>(&I);
  if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
    return;

  Type *SmallestType = nullptr;

  // Gather all cast targets reachable through alloca -> load -> cast.
  for (User *U : AI->users()) {
    auto *Load = dyn_cast<LoadInst>(U);
    if (!Load)
      continue;
    for (User *LU : Load->users()) {
      auto *Cast = dyn_cast<CastInst>(LU);
      if (!Cast)
        continue;
      Type *Ty = Cast->getType();
      // This is an *up*cast: skip truncations and non-integer conversions so
      // the replacement slot can always hold the full 8-bit value as an
      // integer.
      if (!Ty->isIntegerTy() || Ty->getIntegerBitWidth() <= 8)
        continue;
      if (!SmallestType ||
          Ty->getIntegerBitWidth() < SmallestType->getIntegerBitWidth())
        SmallestType = Ty;
    }
  }

  if (!SmallestType)
    return; // no suitable widening cast found; leave the alloca untouched

  // Create the replacement alloca immediately before the original one.
  IRBuilder<> Builder(AI);
  auto *NewAlloca = Builder.CreateAlloca(SmallestType);
  ReplacedValues[AI] = NewAlloca;
  ToRemove.push_back(AI);
}
202+
115203
static void
116204
downcastI64toI32InsertExtractElements(Instruction &I,
117205
SmallVectorImpl<Instruction *> &ToRemove,
@@ -178,7 +266,8 @@ class DXILLegalizationPipeline {
178266
LegalizationPipeline;
179267

180268
void initializeLegalizationPipeline() {
181-
LegalizationPipeline.push_back(fixI8TruncUseChain);
269+
LegalizationPipeline.push_back(upcastI8AllocasAndUses);
270+
LegalizationPipeline.push_back(fixI8UseChain);
182271
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
183272
LegalizationPipeline.push_back(legalizeFreeze);
184273
}
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
4+
define void @const_i8_store() {
5+
; CHECK-LABEL: define void @const_i8_store() {
6+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
7+
; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
8+
; CHECK-NEXT: store i32 1, ptr [[TMP1]], align 4
9+
; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
10+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
11+
; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4
12+
; CHECK-NEXT: ret void
13+
;
14+
%accum.i.flat = alloca [1 x i32], align 4
15+
%i = alloca i8, align 4
16+
store i8 1, ptr %i
17+
%i8.load = load i8, ptr %i
18+
%z = zext i8 %i8.load to i32
19+
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
20+
store i32 %z, ptr %gep, align 4
21+
ret void
22+
}
23+
24+
define void @const_add_i8_store() {
25+
; CHECK-LABEL: define void @const_add_i8_store() {
26+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
27+
; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
28+
; CHECK-NEXT: store i32 4, ptr [[TMP1]], align 4
29+
; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
30+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
31+
; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4
32+
; CHECK-NEXT: ret void
33+
;
34+
%accum.i.flat = alloca [1 x i32], align 4
35+
%i = alloca i8, align 4
36+
%add_i8 = add nsw i8 3, 1
37+
store i8 %add_i8, ptr %i
38+
%i8.load = load i8, ptr %i
39+
%z = zext i8 %i8.load to i32
40+
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
41+
store i32 %z, ptr %gep, align 4
42+
ret void
43+
}
44+
45+
define void @var_i8_store(i1 %cmp.i8) {
46+
; CHECK-LABEL: define void @var_i8_store(
47+
; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
48+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
49+
; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
50+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
51+
; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4
52+
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[TMP1]], align 4
53+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
54+
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP]], align 4
55+
; CHECK-NEXT: ret void
56+
;
57+
%accum.i.flat = alloca [1 x i32], align 4
58+
%i = alloca i8, align 4
59+
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
60+
store i8 %select.i8, ptr %i
61+
%i8.load = load i8, ptr %i
62+
%z = zext i8 %i8.load to i32
63+
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
64+
store i32 %z, ptr %gep, align 4
65+
ret void
66+
}
67+
68+
define void @conflicting_cast(i1 %cmp.i8) {
69+
; CHECK-LABEL: define void @conflicting_cast(
70+
; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
71+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x i32], align 4
72+
; CHECK-NEXT: [[TMP1:%.*]] = alloca i16, align 2
73+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
74+
; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4
75+
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[TMP1]], align 2
76+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 0
77+
; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP1]], align 2
78+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 1
79+
; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP2]], align 2
80+
; CHECK-NEXT: [[TMP4:%.*]] = zext i16 [[TMP3]] to i32
81+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 1
82+
; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4
83+
; CHECK-NEXT: ret void
84+
;
85+
%accum.i.flat = alloca [2 x i32], align 4
86+
%i = alloca i8, align 4
87+
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
88+
store i8 %select.i8, ptr %i
89+
%i8.load = load i8, ptr %i
90+
%z = zext i8 %i8.load to i16
91+
%gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
92+
store i16 %z, ptr %gep1, align 2
93+
%gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
94+
store i16 %z, ptr %gep2, align 2
95+
%z2 = zext i8 %i8.load to i32
96+
%gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
97+
store i32 %z2, ptr %gep3, align 4
98+
ret void
99+
}

0 commit comments

Comments
 (0)