Skip to content

Commit c9cfc73

Browse files
committed
[SYCL][Graph] Addressing PR feedback
- Style fixes
- Make MRoots weak_ptrs instead of shared_ptrs
1 parent 508ee90 commit c9cfc73

File tree

3 files changed: 39 additions and 23 deletions.

sycl/source/detail/graph_impl.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,10 @@ void sortTopological(std::shared_ptr<node_impl> NodeImpl,
8282
for (auto &Succ : NodeImpl->MSuccessors) {
8383
// Check if we've already scheduled this node
8484
auto NextNode = Succ.lock();
85-
if (std::find(Schedule.begin(), Schedule.end(), NextNode) == Schedule.end())
85+
if (std::find(Schedule.begin(), Schedule.end(), NextNode) ==
86+
Schedule.end()) {
8687
sortTopological(NextNode, Schedule);
88+
}
8789
}
8890

8991
Schedule.push_front(NodeImpl);
@@ -93,7 +95,7 @@ void sortTopological(std::shared_ptr<node_impl> NodeImpl,
9395
void exec_graph_impl::schedule() {
9496
if (MSchedule.empty()) {
9597
for (auto &Node : MGraphImpl->MRoots) {
96-
sortTopological(Node, MSchedule);
98+
sortTopological(Node.lock(), MSchedule);
9799
}
98100
}
99101
}
@@ -264,11 +266,14 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
264266
// If any of this node's successors have this requirement then we skip
265267
// adding the current node as a dependency.
266268
for (auto &Succ : Node->MSuccessors) {
267-
if (Succ.lock()->hasRequirement(Req))
269+
if (Succ.lock()->hasRequirement(Req)) {
268270
ShouldAddDep = false;
271+
break;
272+
}
269273
}
270-
if (ShouldAddDep)
274+
if (ShouldAddDep) {
271275
UniqueDeps.insert(Node);
276+
}
272277
}
273278
}
274279
}
@@ -328,7 +333,7 @@ void graph_impl::searchDepthFirst(
328333

329334
for (auto &Root : MRoots) {
330335
std::deque<std::shared_ptr<node_impl>> NodeStack;
331-
if (visitNodeDepthFirst(Root, VisitedNodes, NodeStack, NodeFunc)) {
336+
if (visitNodeDepthFirst(Root.lock(), VisitedNodes, NodeStack, NodeFunc)) {
332337
break;
333338
}
334339
}
@@ -374,8 +379,9 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
374379
SrcFound |= Node == Src;
375380
DestFound |= Node == Dest;
376381

377-
if (SrcFound && DestFound)
382+
if (SrcFound && DestFound) {
378383
break;
384+
}
379385
}
380386

381387
if (!SrcFound) {

sycl/source/detail/graph_impl.hpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ class node_impl {
183183
/// @param CompareContentOnly Skip comparisons related to graph structure,
184184
/// compare only the type and command groups of the nodes
185185
/// @return True if the two nodes are similar
186-
bool isSimilar(std::shared_ptr<node_impl> Node,
187-
bool CompareContentOnly = false) {
186+
bool isSimilar(const std::shared_ptr<node_impl> &Node,
187+
bool CompareContentOnly = false) const {
188188
if (!CompareContentOnly) {
189189
if (MSuccessors.size() != Node->MSuccessors.size())
190190
return false;
@@ -379,7 +379,8 @@ class graph_impl {
379379
sycl::device getDevice() const { return MDevice; }
380380

381381
/// List of root nodes.
382-
std::set<std::shared_ptr<node_impl>> MRoots;
382+
std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
383+
MRoots;
383384

384385
/// Storage for all nodes contained within a graph. Nodes are connected to
385386
/// each other via weak_ptrs and so do not extend each other's lifetimes.
@@ -433,8 +434,8 @@ class graph_impl {
433434
/// @param NodeA pointer to the first node for comparison
434435
/// @param NodeB pointer to the second node for comparison
435436
/// @return true if the same structure is found, false otherwise
436-
static bool checkNodeRecursive(std::shared_ptr<node_impl> NodeA,
437-
std::shared_ptr<node_impl> NodeB) {
437+
static bool checkNodeRecursive(const std::shared_ptr<node_impl> &NodeA,
438+
const std::shared_ptr<node_impl> &NodeB) {
438439
size_t FoundCnt = 0;
439440
for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
440441
for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
@@ -509,10 +510,13 @@ class graph_impl {
509510
}
510511

511512
size_t RootsFound = 0;
512-
for (std::shared_ptr<node_impl> NodeA : MRoots) {
513-
for (std::shared_ptr<node_impl> NodeB : Graph->MRoots) {
514-
if (NodeA->isSimilar(NodeB)) {
515-
if (checkNodeRecursive(NodeA, NodeB)) {
513+
for (std::weak_ptr<node_impl> NodeA : MRoots) {
514+
for (std::weak_ptr<node_impl> NodeB : Graph->MRoots) {
515+
auto NodeALocked = NodeA.lock();
516+
auto NodeBLocked = NodeB.lock();
517+
518+
if (NodeALocked->isSimilar(NodeBLocked)) {
519+
if (checkNodeRecursive(NodeALocked, NodeBLocked)) {
516520
RootsFound++;
517521
break;
518522
}

sycl/unittests/Extensions/CommandGraph.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,8 @@ TEST_F(CommandGraphTest, AddNode) {
497497
ASSERT_NE(sycl::detail::getSyclObjImpl(Node1), nullptr);
498498
ASSERT_FALSE(sycl::detail::getSyclObjImpl(Node1)->isEmpty());
499499
ASSERT_EQ(GraphImpl->MRoots.size(), 1lu);
500-
ASSERT_EQ(*GraphImpl->MRoots.begin(), sycl::detail::getSyclObjImpl(Node1));
500+
ASSERT_EQ((*GraphImpl->MRoots.begin()).lock(),
501+
sycl::detail::getSyclObjImpl(Node1));
501502
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.empty());
502503
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MPredecessors.empty());
503504

@@ -1269,7 +1270,8 @@ TEST_F(CommandGraphTest, EnqueueBarrier) {
12691270
// / \
12701271
// (4) (5)
12711272
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
1272-
for (auto Node : GraphImpl->MRoots) {
1273+
for (auto Root : GraphImpl->MRoots) {
1274+
auto Node = Root.lock();
12731275
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
12741276
auto BarrierNode = Node->MSuccessors.front().lock();
12751277
ASSERT_EQ(BarrierNode->MCGType, sycl::detail::CG::Barrier);
@@ -1309,7 +1311,8 @@ TEST_F(CommandGraphTest, EnqueueBarrierMultipleQueues) {
13091311
// / \
13101312
// (4) (5)
13111313
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
1312-
for (auto Node : GraphImpl->MRoots) {
1314+
for (auto Root : GraphImpl->MRoots) {
1315+
auto Node = Root.lock();
13131316
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
13141317
auto BarrierNode = Node->MSuccessors.front().lock();
13151318
ASSERT_EQ(BarrierNode->MCGType, sycl::detail::CG::Barrier);
@@ -1352,7 +1355,8 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitList) {
13521355
// / \ /
13531356
// (4) (5)
13541357
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
1355-
for (auto Node : GraphImpl->MRoots) {
1358+
for (auto Root : GraphImpl->MRoots) {
1359+
auto Node = Root.lock();
13561360
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
13571361
auto SuccNode = Node->MSuccessors.front().lock();
13581362
if (SuccNode->MCGType == sycl::detail::CG::Barrier) {
@@ -1408,7 +1412,8 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitListMultipleQueues) {
14081412
// \|/
14091413
// (B2)
14101414
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
1411-
for (auto Node : GraphImpl->MRoots) {
1415+
for (auto Root : GraphImpl->MRoots) {
1416+
auto Node = Root.lock();
14121417
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
14131418
auto SuccNode = Node->MSuccessors.front().lock();
14141419
if (SuccNode->MCGType == sycl::detail::CG::Barrier) {
@@ -1470,7 +1475,8 @@ TEST_F(CommandGraphTest, EnqueueMultipleBarrier) {
14701475
// / | \
14711476
// (6) (7) (8) (those nodes also have B1 as a predecessor)
14721477
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
1473-
for (auto Node : GraphImpl->MRoots) {
1478+
for (auto Root : GraphImpl->MRoots) {
1479+
auto Node = Root.lock();
14741480
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
14751481
auto SuccNode = Node->MSuccessors.front().lock();
14761482
if (SuccNode->MCGType == sycl::detail::CG::Barrier) {
@@ -1824,7 +1830,7 @@ TEST_F(CommandGraphTest, MakeEdgeErrors) {
18241830
auto NodeBImpl = sycl::detail::getSyclObjImpl(NodeB);
18251831

18261832
ASSERT_EQ(GraphImpl->MRoots.size(), 1lu);
1827-
ASSERT_EQ(*(GraphImpl->MRoots.begin()), NodeAImpl);
1833+
ASSERT_EQ((*GraphImpl->MRoots.begin()).lock(), NodeAImpl);
18281834

18291835
ASSERT_EQ(NodeAImpl->MSuccessors.size(), 1lu);
18301836
ASSERT_EQ(NodeAImpl->MPredecessors.size(), 0lu);
@@ -2070,7 +2076,7 @@ TEST_F(MultiThreadGraphTest, RecordAddNodesInOrderQueue) {
20702076
ASSERT_EQ(GraphImpl->MRoots.size(), 1lu);
20712077

20722078
// Check structure graph
2073-
auto CurrentNode = *GraphImpl->MRoots.begin();
2079+
auto CurrentNode = (*GraphImpl->MRoots.begin()).lock();
20742080
for (size_t i = 1; i <= GraphImpl->getNumberOfNodes(); i++) {
20752081
EXPECT_LE(CurrentNode->MSuccessors.size(), 1lu);
20762082

0 commit comments

Comments
 (0)