12
12
#include " llvm/IR/IRBuilder.h"
13
13
#include " llvm/IR/InstIterator.h"
14
14
#include " llvm/IR/Instruction.h"
15
+ #include " llvm/IR/Instructions.h"
15
16
#include " llvm/Pass.h"
16
17
#include " llvm/Transforms/Utils/BasicBlockUtils.h"
17
18
#include < functional>
@@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I,
31
32
ToRemove.push_back (FI);
32
33
}
33
34
34
- static void fixI8TruncUseChain (Instruction &I,
35
- SmallVectorImpl<Instruction *> &ToRemove,
36
- DenseMap<Value *, Value *> &ReplacedValues) {
35
+ static void fixI8UseChain (Instruction &I,
36
+ SmallVectorImpl<Instruction *> &ToRemove,
37
+ DenseMap<Value *, Value *> &ReplacedValues) {
37
38
38
39
auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
39
40
Type *InstrType = IntegerType::get (I.getContext (), 32 );
40
41
41
42
for (unsigned OpIdx = 0 ; OpIdx < I.getNumOperands (); ++OpIdx) {
42
43
Value *Op = I.getOperand (OpIdx);
43
- if (ReplacedValues.count (Op))
44
+ if (ReplacedValues.count (Op) &&
45
+ ReplacedValues[Op]->getType ()->isIntegerTy ())
44
46
InstrType = ReplacedValues[Op]->getType ();
45
47
}
46
48
@@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I,
73
75
}
74
76
}
75
77
78
+ if (auto *Store = dyn_cast<StoreInst>(&I)) {
79
+ if (!Store->getValueOperand ()->getType ()->isIntegerTy (8 ))
80
+ return ;
81
+ SmallVector<Value *> NewOperands;
82
+ ProcessOperands (NewOperands);
83
+ Value *NewStore = Builder.CreateStore (NewOperands[0 ], NewOperands[1 ]);
84
+ ReplacedValues[Store] = NewStore;
85
+ ToRemove.push_back (Store);
86
+ return ;
87
+ }
88
+
89
+ if (auto *Load = dyn_cast<LoadInst>(&I)) {
90
+ if (!I.getType ()->isIntegerTy (8 ))
91
+ return ;
92
+ SmallVector<Value *> NewOperands;
93
+ ProcessOperands (NewOperands);
94
+ Type *ElementType = NewOperands[0 ]->getType ();
95
+ if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0 ]))
96
+ ElementType = AI->getAllocatedType ();
97
+ LoadInst *NewLoad = Builder.CreateLoad (ElementType, NewOperands[0 ]);
98
+ ReplacedValues[Load] = NewLoad;
99
+ ToRemove.push_back (Load);
100
+ return ;
101
+ }
102
+
76
103
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
77
104
if (!I.getType ()->isIntegerTy (8 ))
78
105
return ;
@@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I,
81
108
Value *NewInst =
82
109
Builder.CreateBinOp (BO->getOpcode (), NewOperands[0 ], NewOperands[1 ]);
83
110
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
84
- if (OBO->hasNoSignedWrap ())
85
- cast<BinaryOperator>(NewInst)->setHasNoSignedWrap ();
86
- if (OBO->hasNoUnsignedWrap ())
87
- cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap ();
111
+ auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
112
+ if (NewBO && OBO->hasNoSignedWrap ())
113
+ NewBO->setHasNoSignedWrap ();
114
+ if (NewBO && OBO->hasNoUnsignedWrap ())
115
+ NewBO->setHasNoUnsignedWrap ();
88
116
}
89
117
ReplacedValues[BO] = NewInst;
90
118
ToRemove.push_back (BO);
91
119
return ;
92
120
}
93
121
122
+ if (auto *Sel = dyn_cast<SelectInst>(&I)) {
123
+ if (!I.getType ()->isIntegerTy (8 ))
124
+ return ;
125
+ SmallVector<Value *> NewOperands;
126
+ ProcessOperands (NewOperands);
127
+ Value *NewInst = Builder.CreateSelect (Sel->getCondition (), NewOperands[1 ],
128
+ NewOperands[2 ]);
129
+ ReplacedValues[Sel] = NewInst;
130
+ ToRemove.push_back (Sel);
131
+ return ;
132
+ }
133
+
94
134
if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
95
135
if (!Cmp->getOperand (0 )->getType ()->isIntegerTy (8 ))
96
136
return ;
@@ -105,13 +145,50 @@ static void fixI8TruncUseChain(Instruction &I,
105
145
}
106
146
107
147
if (auto *Cast = dyn_cast<CastInst>(&I)) {
148
+
108
149
if (Cast->getSrcTy ()->isIntegerTy (8 )) {
109
150
ToRemove.push_back (Cast);
110
151
Cast->replaceAllUsesWith (ReplacedValues[Cast->getOperand (0 )]);
111
152
}
112
153
}
113
154
}
114
155
156
+ static void upcastI8AllocasAndUses (Instruction &I,
157
+ SmallVectorImpl<Instruction *> &ToRemove,
158
+ DenseMap<Value *, Value *> &ReplacedValues) {
159
+ auto *AI = dyn_cast<AllocaInst>(&I);
160
+ if (!AI || !AI->getAllocatedType ()->isIntegerTy (8 ))
161
+ return ;
162
+
163
+ std::optional<Type *> TargetType;
164
+ bool Conflict = false ;
165
+ for (User *U : AI->users ()) {
166
+ auto *Load = dyn_cast<LoadInst>(U);
167
+ if (!Load)
168
+ continue ;
169
+ for (User *LU : Load->users ()) {
170
+ auto *Cast = dyn_cast<CastInst>(LU);
171
+ if (!Cast)
172
+ continue ;
173
+ Type *T = Cast->getType ();
174
+ if (!TargetType)
175
+ TargetType = T;
176
+
177
+ if (TargetType.value () != T) {
178
+ Conflict = true ;
179
+ break ;
180
+ }
181
+ }
182
+ }
183
+ if (!TargetType || Conflict)
184
+ return ;
185
+
186
+ IRBuilder<> Builder (AI);
187
+ AllocaInst *NewAlloca = Builder.CreateAlloca (TargetType.value ());
188
+ ReplacedValues[AI] = NewAlloca;
189
+ ToRemove.push_back (AI);
190
+ }
191
+
115
192
static void
116
193
downcastI64toI32InsertExtractElements (Instruction &I,
117
194
SmallVectorImpl<Instruction *> &ToRemove,
@@ -178,7 +255,8 @@ class DXILLegalizationPipeline {
178
255
LegalizationPipeline;
179
256
180
257
void initializeLegalizationPipeline () {
181
- LegalizationPipeline.push_back (fixI8TruncUseChain);
258
+ LegalizationPipeline.push_back (upcastI8AllocasAndUses);
259
+ LegalizationPipeline.push_back (fixI8UseChain);
182
260
LegalizationPipeline.push_back (downcastI64toI32InsertExtractElements);
183
261
LegalizationPipeline.push_back (legalizeFreeze);
184
262
}
0 commit comments