Skip to content

Commit 747d8f3

Browse files
authored
[SandboxVec][DAG] Implement PredIterator (#111604)
This patch implements an iterator for iterating over both use-def and mem dependencies of MemDGNodes.
1 parent c2063de commit 747d8f3

File tree

3 files changed

+158
-0
lines changed

3 files changed

+158
-0
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,54 @@ enum class DGNodeID {
4040
MemDGNode,
4141
};
4242

43+
class DGNode;
44+
class MemDGNode;
45+
class DependencyGraph;
46+
47+
/// While OpIt points to a Value that is not an Instruction keep incrementing
48+
/// it. \Returns the first iterator that points to an Instruction, or end.
49+
[[nodiscard]] static User::op_iterator skipNonInstr(User::op_iterator OpIt,
50+
User::op_iterator OpItE) {
51+
while (OpIt != OpItE && !isa<Instruction>((*OpIt).get()))
52+
++OpIt;
53+
return OpIt;
54+
}
55+
56+
/// Iterate over both def-use and mem dependencies.
57+
class PredIterator {
58+
User::op_iterator OpIt;
59+
User::op_iterator OpItE;
60+
DenseSet<MemDGNode *>::iterator MemIt;
61+
DGNode *N = nullptr;
62+
DependencyGraph *DAG = nullptr;
63+
64+
PredIterator(const User::op_iterator &OpIt, const User::op_iterator &OpItE,
65+
const DenseSet<MemDGNode *>::iterator &MemIt, DGNode *N,
66+
DependencyGraph &DAG)
67+
: OpIt(OpIt), OpItE(OpItE), MemIt(MemIt), N(N), DAG(&DAG) {}
68+
PredIterator(const User::op_iterator &OpIt, const User::op_iterator &OpItE,
69+
DGNode *N, DependencyGraph &DAG)
70+
: OpIt(OpIt), OpItE(OpItE), N(N), DAG(&DAG) {}
71+
friend class DGNode; // For constructor
72+
friend class MemDGNode; // For constructor
73+
74+
public:
75+
using difference_type = std::ptrdiff_t;
76+
using value_type = DGNode *;
77+
using pointer = value_type *;
78+
using reference = value_type &;
79+
using iterator_category = std::input_iterator_tag;
80+
value_type operator*();
81+
PredIterator &operator++();
82+
PredIterator operator++(int) {
83+
auto Copy = *this;
84+
++(*this);
85+
return Copy;
86+
}
87+
bool operator==(const PredIterator &Other) const;
88+
bool operator!=(const PredIterator &Other) const { return !(*this == Other); }
89+
};
90+
4391
/// A DependencyGraph Node that points to an Instruction and contains memory
4492
/// dependency edges.
4593
class DGNode {
@@ -63,6 +111,23 @@ class DGNode {
63111
virtual ~DGNode() = default;
64112
/// \Returns true if this is before \p Other in program order.
65113
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
114+
using iterator = PredIterator;
115+
virtual iterator preds_begin(DependencyGraph &DAG) {
116+
return PredIterator(skipNonInstr(I->op_begin(), I->op_end()), I->op_end(),
117+
this, DAG);
118+
}
119+
virtual iterator preds_end(DependencyGraph &DAG) {
120+
return PredIterator(I->op_end(), I->op_end(), this, DAG);
121+
}
122+
iterator preds_begin(DependencyGraph &DAG) const {
123+
return const_cast<DGNode *>(this)->preds_begin(DAG);
124+
}
125+
iterator preds_end(DependencyGraph &DAG) const {
126+
return const_cast<DGNode *>(this)->preds_end(DAG);
127+
}
128+
iterator_range<iterator> preds(DependencyGraph &DAG) const {
129+
return make_range(preds_begin(DAG), preds_end(DAG));
130+
}
66131

67132
static bool isStackSaveOrRestoreIntrinsic(Instruction *I) {
68133
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
@@ -145,6 +210,14 @@ class MemDGNode final : public DGNode {
145210
static bool classof(const DGNode *Other) {
146211
return Other->SubclassID == DGNodeID::MemDGNode;
147212
}
213+
iterator preds_begin(DependencyGraph &DAG) override {
214+
auto OpEndIt = I->op_end();
215+
return PredIterator(skipNonInstr(I->op_begin(), OpEndIt), OpEndIt,
216+
MemPreds.begin(), this, DAG);
217+
}
218+
iterator preds_end(DependencyGraph &DAG) override {
219+
return PredIterator(I->op_end(), I->op_end(), MemPreds.end(), this, DAG);
220+
}
148221
/// \Returns the previous Mem DGNode in instruction order.
149222
MemDGNode *getPrevNode() const { return PrevMemN; }
150223
/// \Returns the next Mem DGNode in instruction order.

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,54 @@
88

99
#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
1010
#include "llvm/ADT/ArrayRef.h"
11+
#include "llvm/SandboxIR/Instruction.h"
1112
#include "llvm/SandboxIR/Utils.h"
1213

1314
namespace llvm::sandboxir {
1415

16+
PredIterator::value_type PredIterator::operator*() {
17+
// If it's a DGNode then we dereference the operand iterator.
18+
if (!isa<MemDGNode>(N)) {
19+
assert(OpIt != OpItE && "Can't dereference end iterator!");
20+
return DAG->getNode(cast<Instruction>((Value *)*OpIt));
21+
}
22+
// It's a MemDGNode, so we check if we return either the use-def operand,
23+
// or a mem predecessor.
24+
if (OpIt != OpItE)
25+
return DAG->getNode(cast<Instruction>((Value *)*OpIt));
26+
assert(MemIt != cast<MemDGNode>(N)->memPreds().end() &&
27+
"Cant' dereference end iterator!");
28+
return *MemIt;
29+
}
30+
31+
PredIterator &PredIterator::operator++() {
32+
// If it's a DGNode then we increment the use-def iterator.
33+
if (!isa<MemDGNode>(N)) {
34+
assert(OpIt != OpItE && "Already at end!");
35+
++OpIt;
36+
// Skip operands that are not instructions.
37+
OpIt = skipNonInstr(OpIt, OpItE);
38+
return *this;
39+
}
40+
// It's a MemDGNode, so if we are not at the end of the use-def iterator we
41+
// need to first increment that.
42+
if (OpIt != OpItE) {
43+
++OpIt;
44+
// Skip operands that are not instructions.
45+
OpIt = skipNonInstr(OpIt, OpItE);
46+
return *this;
47+
}
48+
assert(MemIt != cast<MemDGNode>(N)->memPreds().end() && "Already at end!");
49+
++MemIt;
50+
return *this;
51+
}
52+
53+
bool PredIterator::operator==(const PredIterator &Other) const {
54+
assert(DAG == Other.DAG && "Iterators of different DAGs!");
55+
assert(N == Other.N && "Iterators of different nodes!");
56+
return OpIt == Other.OpIt && MemIt == Other.MemIt;
57+
}
58+
1559
#ifndef NDEBUG
1660
void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
1761
I->dumpOS(OS);

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,53 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
240240
EXPECT_TRUE(N1->hasMemPred(N0));
241241
EXPECT_FALSE(N0->hasMemPred(N1));
242242

243+
// Check preds().
244+
EXPECT_TRUE(N0->preds(DAG).empty());
245+
EXPECT_THAT(N1->preds(DAG), testing::ElementsAre(N0));
246+
243247
// Check memPreds().
244248
EXPECT_TRUE(N0->memPreds().empty());
245249
EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
246250
EXPECT_TRUE(N2->memPreds().empty());
247251
}
248252

253+
TEST_F(DependencyGraphTest, Preds) {
254+
parseIR(C, R"IR(
255+
declare ptr @bar(i8)
256+
define i8 @foo(i8 %v0, i8 %v1) {
257+
%add0 = add i8 %v0, %v0
258+
%add1 = add i8 %v1, %v1
259+
%add2 = add i8 %add0, %add1
260+
%ptr = call ptr @bar(i8 %add1)
261+
store i8 %add2, ptr %ptr
262+
ret i8 %add2
263+
}
264+
)IR");
265+
llvm::Function *LLVMF = &*M->getFunction("foo");
266+
sandboxir::Context Ctx(C);
267+
auto *F = Ctx.createFunction(LLVMF);
268+
auto *BB = &*F->begin();
269+
auto It = BB->begin();
270+
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
271+
DAG.extend({&*BB->begin(), BB->getTerminator()});
272+
273+
auto *AddN0 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*It++));
274+
auto *AddN1 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*It++));
275+
auto *AddN2 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*It++));
276+
auto *CallN = DAG.getNode(cast<sandboxir::CallInst>(&*It++));
277+
auto *StN = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
278+
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
279+
280+
// Check preds().
281+
EXPECT_THAT(AddN0->preds(DAG), testing::ElementsAre());
282+
EXPECT_THAT(AddN1->preds(DAG), testing::ElementsAre());
283+
EXPECT_THAT(AddN2->preds(DAG), testing::ElementsAre(AddN0, AddN1));
284+
EXPECT_THAT(CallN->preds(DAG), testing::ElementsAre(AddN1));
285+
EXPECT_THAT(StN->preds(DAG),
286+
testing::UnorderedElementsAre(CallN, CallN, AddN2));
287+
EXPECT_THAT(RetN->preds(DAG), testing::ElementsAre(AddN2));
288+
}
289+
249290
TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
250291
parseIR(C, R"IR(
251292
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {

0 commit comments

Comments
 (0)