Skip to content

Commit 8d825cb

Browse files
authored
[DirectX] Address PR comments to #131221 (#131706)
- [x] [Don't include static inside anonymous namespace](#131221 (comment)) - [x] [Use DenseMap](#131221 (comment)) - [x] [remove {}](#131221 (comment)) - [x] [remove std::stack with llvm::reverse of SmallVector](#131221 (comment)) - [x] [replace std::vector with llvm::SmallVector](#131221 (comment)) - [x] [Remove legalize comment block](#131221 (comment)) and [double comment block](#131221 (comment)) - [x] [Refactor fixI8TruncUseChain to remove nesting](#131221 (comment)) - [x] [add asserts on assumptions](#131221 (comment))
1 parent 969ac10 commit 8d825cb

File tree

1 file changed

+62
-62
lines changed

1 file changed

+62
-62
lines changed

llvm/lib/Target/DirectX/DXILLegalizePass.cpp

Lines changed: 62 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,7 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===---------------------------------------------------------------------===//
8-
//===---------------------------------------------------------------------===//
9-
///
10-
/// \file This file contains a pass to remove i8 truncations and i64 extract
11-
/// and insert elements.
12-
///
13-
//===----------------------------------------------------------------------===//
8+
149
#include "DXILLegalizePass.h"
1510
#include "DirectX.h"
1611
#include "llvm/IR/Function.h"
@@ -20,37 +15,24 @@
2015
#include "llvm/Pass.h"
2116
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
2217
#include <functional>
23-
#include <map>
24-
#include <stack>
25-
#include <vector>
2618

2719
#define DEBUG_TYPE "dxil-legalize"
2820

2921
using namespace llvm;
30-
namespace {
3122

3223
static void fixI8TruncUseChain(Instruction &I,
33-
std::stack<Instruction *> &ToRemove,
34-
std::map<Value *, Value *> &ReplacedValues) {
35-
36-
auto *Cmp = dyn_cast<CmpInst>(&I);
24+
SmallVectorImpl<Instruction *> &ToRemove,
25+
DenseMap<Value *, Value *> &ReplacedValues) {
3726

38-
if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
39-
if (Trunc->getDestTy()->isIntegerTy(8)) {
40-
ReplacedValues[Trunc] = Trunc->getOperand(0);
41-
ToRemove.push(Trunc);
42-
}
43-
} else if (I.getType()->isIntegerTy(8) ||
44-
(Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) {
45-
IRBuilder<> Builder(&I);
46-
47-
std::vector<Value *> NewOperands;
27+
auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
4828
Type *InstrType = IntegerType::get(I.getContext(), 32);
29+
4930
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
5031
Value *Op = I.getOperand(OpIdx);
5132
if (ReplacedValues.count(Op))
5233
InstrType = ReplacedValues[Op]->getType();
5334
}
35+
5436
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
5537
Value *Op = I.getOperand(OpIdx);
5638
if (ReplacedValues.count(Op))
@@ -61,47 +43,68 @@ static void fixI8TruncUseChain(Instruction &I,
6143
// Note: options here are sext or sextOrTrunc.
6244
// Since i8 isn't supported, we assume new values
6345
// will always have a higher bitness.
46+
assert(NewBitWidth > Value.getBitWidth() &&
47+
"Replacement's BitWidth should be larger than Current.");
6448
APInt NewValue = Value.sext(NewBitWidth);
6549
NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
6650
} else {
6751
assert(!Op->getType()->isIntegerTy(8));
6852
NewOperands.push_back(Op);
6953
}
7054
}
71-
72-
Value *NewInst = nullptr;
73-
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
74-
NewInst =
75-
Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
76-
77-
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
78-
if (OBO->hasNoSignedWrap())
79-
cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
80-
if (OBO->hasNoUnsignedWrap())
81-
cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
82-
}
83-
} else if (Cmp) {
84-
NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0],
85-
NewOperands[1]);
86-
Cmp->replaceAllUsesWith(NewInst);
55+
};
56+
IRBuilder<> Builder(&I);
57+
if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
58+
if (Trunc->getDestTy()->isIntegerTy(8)) {
59+
ReplacedValues[Trunc] = Trunc->getOperand(0);
60+
ToRemove.push_back(Trunc);
61+
return;
8762
}
63+
}
8864

89-
if (NewInst) {
90-
ReplacedValues[&I] = NewInst;
91-
ToRemove.push(&I);
65+
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
66+
if (!I.getType()->isIntegerTy(8))
67+
return;
68+
SmallVector<Value *> NewOperands;
69+
ProcessOperands(NewOperands);
70+
Value *NewInst =
71+
Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
72+
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
73+
if (OBO->hasNoSignedWrap())
74+
cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
75+
if (OBO->hasNoUnsignedWrap())
76+
cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
9277
}
93-
} else if (auto *Cast = dyn_cast<CastInst>(&I)) {
78+
ReplacedValues[BO] = NewInst;
79+
ToRemove.push_back(BO);
80+
return;
81+
}
82+
83+
if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
84+
if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
85+
return;
86+
SmallVector<Value *> NewOperands;
87+
ProcessOperands(NewOperands);
88+
Value *NewInst =
89+
Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]);
90+
Cmp->replaceAllUsesWith(NewInst);
91+
ReplacedValues[Cmp] = NewInst;
92+
ToRemove.push_back(Cmp);
93+
return;
94+
}
95+
96+
if (auto *Cast = dyn_cast<CastInst>(&I)) {
9497
if (Cast->getSrcTy()->isIntegerTy(8)) {
95-
ToRemove.push(Cast);
98+
ToRemove.push_back(Cast);
9699
Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
97100
}
98101
}
99102
}
100103

101104
static void
102105
downcastI64toI32InsertExtractElements(Instruction &I,
103-
std::stack<Instruction *> &ToRemove,
104-
std::map<Value *, Value *> &) {
106+
SmallVectorImpl<Instruction *> &ToRemove,
107+
DenseMap<Value *, Value *> &) {
105108

106109
if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
107110
Value *Idx = Extract->getIndexOperand();
@@ -115,7 +118,7 @@ downcastI64toI32InsertExtractElements(Instruction &I,
115118
Extract->getVectorOperand(), Idx32, Extract->getName());
116119

117120
Extract->replaceAllUsesWith(NewExtract);
118-
ToRemove.push(Extract);
121+
ToRemove.push_back(Extract);
119122
}
120123
}
121124

@@ -132,38 +135,35 @@ downcastI64toI32InsertExtractElements(Instruction &I,
132135
Insert->getName());
133136

134137
Insert->replaceAllUsesWith(Insert32Index);
135-
ToRemove.push(Insert);
138+
ToRemove.push_back(Insert);
136139
}
137140
}
138141
}
139142

143+
namespace {
140144
class DXILLegalizationPipeline {
141145

142146
public:
143147
DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
144148

145149
bool runLegalizationPipeline(Function &F) {
146-
std::stack<Instruction *> ToRemove;
147-
std::map<Value *, Value *> ReplacedValues;
150+
SmallVector<Instruction *> ToRemove;
151+
DenseMap<Value *, Value *> ReplacedValues;
148152
for (auto &I : instructions(F)) {
149-
for (auto &LegalizationFn : LegalizationPipeline) {
153+
for (auto &LegalizationFn : LegalizationPipeline)
150154
LegalizationFn(I, ToRemove, ReplacedValues);
151-
}
152155
}
153-
bool MadeChanges = !ToRemove.empty();
154156

155-
while (!ToRemove.empty()) {
156-
Instruction *I = ToRemove.top();
157-
I->eraseFromParent();
158-
ToRemove.pop();
159-
}
157+
for (auto *Inst : reverse(ToRemove))
158+
Inst->eraseFromParent();
160159

161-
return MadeChanges;
160+
return !ToRemove.empty();
162161
}
163162

164163
private:
165-
std::vector<std::function<void(Instruction &, std::stack<Instruction *> &,
166-
std::map<Value *, Value *> &)>>
164+
SmallVector<
165+
std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
166+
DenseMap<Value *, Value *> &)>>
167167
LegalizationPipeline;
168168

169169
void initializeLegalizationPipeline() {

0 commit comments

Comments
 (0)