5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===---------------------------------------------------------------------===//
8
- // ===---------------------------------------------------------------------===//
9
- // /
10
- // / \file This file contains a pass to remove i8 truncations and i64 extract
11
- // / and insert elements.
12
- // /
13
- // ===----------------------------------------------------------------------===//
8
+
14
9
#include " DXILLegalizePass.h"
15
10
#include " DirectX.h"
16
11
#include " llvm/IR/Function.h"
20
15
#include " llvm/Pass.h"
21
16
#include " llvm/Transforms/Utils/BasicBlockUtils.h"
22
17
#include < functional>
23
- #include < map>
24
- #include < stack>
25
- #include < vector>
26
18
27
19
#define DEBUG_TYPE " dxil-legalize"
28
20
29
21
using namespace llvm ;
30
- namespace {
31
22
32
23
static void fixI8TruncUseChain (Instruction &I,
33
- std::stack<Instruction *> &ToRemove,
34
- std::map<Value *, Value *> &ReplacedValues) {
35
-
36
- auto *Cmp = dyn_cast<CmpInst>(&I);
24
+ SmallVectorImpl<Instruction *> &ToRemove,
25
+ DenseMap<Value *, Value *> &ReplacedValues) {
37
26
38
- if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
39
- if (Trunc->getDestTy ()->isIntegerTy (8 )) {
40
- ReplacedValues[Trunc] = Trunc->getOperand (0 );
41
- ToRemove.push (Trunc);
42
- }
43
- } else if (I.getType ()->isIntegerTy (8 ) ||
44
- (Cmp && Cmp->getOperand (0 )->getType ()->isIntegerTy (8 ))) {
45
- IRBuilder<> Builder (&I);
46
-
47
- std::vector<Value *> NewOperands;
27
+ auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
48
28
Type *InstrType = IntegerType::get (I.getContext (), 32 );
29
+
49
30
for (unsigned OpIdx = 0 ; OpIdx < I.getNumOperands (); ++OpIdx) {
50
31
Value *Op = I.getOperand (OpIdx);
51
32
if (ReplacedValues.count (Op))
52
33
InstrType = ReplacedValues[Op]->getType ();
53
34
}
35
+
54
36
for (unsigned OpIdx = 0 ; OpIdx < I.getNumOperands (); ++OpIdx) {
55
37
Value *Op = I.getOperand (OpIdx);
56
38
if (ReplacedValues.count (Op))
@@ -61,47 +43,68 @@ static void fixI8TruncUseChain(Instruction &I,
61
43
// Note: options here are sext or sextOrTrunc.
62
44
// Since i8 isn't supported, we assume new values
63
45
// will always have a higher bitness.
46
+ assert (NewBitWidth > Value.getBitWidth () &&
47
+ " Replacement's BitWidth should be larger than Current." );
64
48
APInt NewValue = Value.sext (NewBitWidth);
65
49
NewOperands.push_back (ConstantInt::get (InstrType, NewValue));
66
50
} else {
67
51
assert (!Op->getType ()->isIntegerTy (8 ));
68
52
NewOperands.push_back (Op);
69
53
}
70
54
}
71
-
72
- Value *NewInst = nullptr ;
73
- if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
74
- NewInst =
75
- Builder.CreateBinOp (BO->getOpcode (), NewOperands[0 ], NewOperands[1 ]);
76
-
77
- if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
78
- if (OBO->hasNoSignedWrap ())
79
- cast<BinaryOperator>(NewInst)->setHasNoSignedWrap ();
80
- if (OBO->hasNoUnsignedWrap ())
81
- cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap ();
82
- }
83
- } else if (Cmp) {
84
- NewInst = Builder.CreateCmp (Cmp->getPredicate (), NewOperands[0 ],
85
- NewOperands[1 ]);
86
- Cmp->replaceAllUsesWith (NewInst);
55
+ };
56
+ IRBuilder<> Builder (&I);
57
+ if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
58
+ if (Trunc->getDestTy ()->isIntegerTy (8 )) {
59
+ ReplacedValues[Trunc] = Trunc->getOperand (0 );
60
+ ToRemove.push_back (Trunc);
61
+ return ;
87
62
}
63
+ }
88
64
89
- if (NewInst) {
90
- ReplacedValues[&I] = NewInst;
91
- ToRemove.push (&I);
65
+ if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
66
+ if (!I.getType ()->isIntegerTy (8 ))
67
+ return ;
68
+ SmallVector<Value *> NewOperands;
69
+ ProcessOperands (NewOperands);
70
+ Value *NewInst =
71
+ Builder.CreateBinOp (BO->getOpcode (), NewOperands[0 ], NewOperands[1 ]);
72
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
73
+ if (OBO->hasNoSignedWrap ())
74
+ cast<BinaryOperator>(NewInst)->setHasNoSignedWrap ();
75
+ if (OBO->hasNoUnsignedWrap ())
76
+ cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap ();
92
77
}
93
- } else if (auto *Cast = dyn_cast<CastInst>(&I)) {
78
+ ReplacedValues[BO] = NewInst;
79
+ ToRemove.push_back (BO);
80
+ return ;
81
+ }
82
+
83
+ if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
84
+ if (!Cmp->getOperand (0 )->getType ()->isIntegerTy (8 ))
85
+ return ;
86
+ SmallVector<Value *> NewOperands;
87
+ ProcessOperands (NewOperands);
88
+ Value *NewInst =
89
+ Builder.CreateCmp (Cmp->getPredicate (), NewOperands[0 ], NewOperands[1 ]);
90
+ Cmp->replaceAllUsesWith (NewInst);
91
+ ReplacedValues[Cmp] = NewInst;
92
+ ToRemove.push_back (Cmp);
93
+ return ;
94
+ }
95
+
96
+ if (auto *Cast = dyn_cast<CastInst>(&I)) {
94
97
if (Cast->getSrcTy ()->isIntegerTy (8 )) {
95
- ToRemove.push (Cast);
98
+ ToRemove.push_back (Cast);
96
99
Cast->replaceAllUsesWith (ReplacedValues[Cast->getOperand (0 )]);
97
100
}
98
101
}
99
102
}
100
103
101
104
static void
102
105
downcastI64toI32InsertExtractElements (Instruction &I,
103
- std::stack <Instruction *> &ToRemove,
104
- std::map <Value *, Value *> &) {
106
+ SmallVectorImpl <Instruction *> &ToRemove,
107
+ DenseMap <Value *, Value *> &) {
105
108
106
109
if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
107
110
Value *Idx = Extract->getIndexOperand ();
@@ -115,7 +118,7 @@ downcastI64toI32InsertExtractElements(Instruction &I,
115
118
Extract->getVectorOperand (), Idx32, Extract->getName ());
116
119
117
120
Extract->replaceAllUsesWith (NewExtract);
118
- ToRemove.push (Extract);
121
+ ToRemove.push_back (Extract);
119
122
}
120
123
}
121
124
@@ -132,38 +135,35 @@ downcastI64toI32InsertExtractElements(Instruction &I,
132
135
Insert->getName ());
133
136
134
137
Insert->replaceAllUsesWith (Insert32Index);
135
- ToRemove.push (Insert);
138
+ ToRemove.push_back (Insert);
136
139
}
137
140
}
138
141
}
139
142
143
+ namespace {
140
144
class DXILLegalizationPipeline {
141
145
142
146
public:
143
147
DXILLegalizationPipeline () { initializeLegalizationPipeline (); }
144
148
145
149
bool runLegalizationPipeline (Function &F) {
146
- std::stack <Instruction *> ToRemove;
147
- std::map <Value *, Value *> ReplacedValues;
150
+ SmallVector <Instruction *> ToRemove;
151
+ DenseMap <Value *, Value *> ReplacedValues;
148
152
for (auto &I : instructions (F)) {
149
- for (auto &LegalizationFn : LegalizationPipeline) {
153
+ for (auto &LegalizationFn : LegalizationPipeline)
150
154
LegalizationFn (I, ToRemove, ReplacedValues);
151
- }
152
155
}
153
- bool MadeChanges = !ToRemove.empty ();
154
156
155
- while (!ToRemove.empty ()) {
156
- Instruction *I = ToRemove.top ();
157
- I->eraseFromParent ();
158
- ToRemove.pop ();
159
- }
157
+ for (auto *Inst : reverse (ToRemove))
158
+ Inst->eraseFromParent ();
160
159
161
- return MadeChanges ;
160
+ return !ToRemove. empty () ;
162
161
}
163
162
164
163
private:
165
- std::vector<std::function<void (Instruction &, std::stack<Instruction *> &,
166
- std::map<Value *, Value *> &)>>
164
+ SmallVector<
165
+ std::function<void (Instruction &, SmallVectorImpl<Instruction *> &,
166
+ DenseMap<Value *, Value *> &)>>
167
167
LegalizationPipeline;
168
168
169
169
void initializeLegalizationPipeline () {
0 commit comments