Skip to content

Commit 7a75b54

Browse files
[SYCL] Fix execution graph cleanup on memory object destruction (#1065)
The previous algorithm deleted a visited node if it had no dependencies left, but it only visited the immediate users of deleted nodes while leaving all potential indirect users unchanged. This resulted in leftover dependencies on dead memory objects. The new algorithm addresses this problem by traversing the entire graph component of the memory object being deleted instead.

Signed-off-by: Sergey Semenov <[email protected]>
1 parent 4445462 commit 7a75b54

File tree

4 files changed

+189
-51
lines changed

4 files changed

+189
-51
lines changed

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 61 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <cstdlib>
1919
#include <fstream>
20+
#include <map>
2021
#include <memory>
2122
#include <queue>
2223
#include <set>
@@ -633,43 +634,71 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
633634
}
634635

635636
void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
636-
if (Record->MAllocaCommands.empty())
637+
std::vector<AllocaCommandBase *> &AllocaCommands = Record->MAllocaCommands;
638+
if (AllocaCommands.empty())
637639
return;
638640

639-
std::queue<Command *> RemoveQueue;
641+
std::queue<Command *> ToVisit;
640642
std::set<Command *> Visited;
643+
std::vector<Command *> CmdsToDelete;
644+
// First, mark all allocas for deletion and their direct users for traversal
645+
// Dependencies of the users will be cleaned up during the traversal
646+
for (Command *AllocaCmd : AllocaCommands) {
647+
Visited.insert(AllocaCmd);
648+
for (Command *UserCmd : AllocaCmd->MUsers)
649+
ToVisit.push(UserCmd);
650+
CmdsToDelete.push_back(AllocaCmd);
651+
// These commands will be deleted later, clear users now to avoid
652+
// updating them during edge removal
653+
AllocaCmd->MUsers.clear();
654+
}
641655

642-
// TODO: release commands need special handling here as they are not reachable
643-
// from alloca commands
644-
645-
for (AllocaCommandBase *AllocaCmd : Record->MAllocaCommands) {
646-
if (Visited.find(AllocaCmd) == Visited.end())
647-
RemoveQueue.push(AllocaCmd);
648-
// Use BFS to find and process all users of removal candidate
649-
while (!RemoveQueue.empty()) {
650-
Command *CandidateCommand = RemoveQueue.front();
651-
RemoveQueue.pop();
652-
653-
if (Visited.insert(CandidateCommand).second) {
654-
for (Command *UserCmd : CandidateCommand->MUsers) {
655-
// As candidate command is about to be freed, we need
656-
// to remove it from dependency list of other commands.
657-
auto NewEnd =
658-
std::remove_if(UserCmd->MDeps.begin(), UserCmd->MDeps.end(),
659-
[CandidateCommand](const DepDesc &Dep) {
660-
return Dep.MDepCommand == CandidateCommand;
661-
});
662-
UserCmd->MDeps.erase(NewEnd, UserCmd->MDeps.end());
663-
664-
// Commands that have no unsatisfied dependencies can be executed
665-
// and are good candidates for clean up.
666-
if (UserCmd->MDeps.empty())
667-
RemoveQueue.push(UserCmd);
668-
}
669-
CandidateCommand->getEvent()->setCommand(nullptr);
670-
delete CandidateCommand;
671-
}
656+
// Traverse the graph using BFS
657+
while (!ToVisit.empty()) {
658+
Command *Cmd = ToVisit.front();
659+
ToVisit.pop();
660+
661+
if (!Visited.insert(Cmd).second)
662+
continue;
663+
664+
for (Command *UserCmd : Cmd->MUsers)
665+
ToVisit.push(UserCmd);
666+
667+
// Delete all dependencies on any allocations being removed
668+
// Track which commands should have their users updated
669+
std::map<Command *, bool> ShouldBeUpdated;
670+
auto NewEnd = std::remove_if(
671+
Cmd->MDeps.begin(), Cmd->MDeps.end(), [&](const DepDesc &Dep) {
672+
if (std::find(AllocaCommands.begin(), AllocaCommands.end(),
673+
Dep.MAllocaCmd) != AllocaCommands.end()) {
674+
ShouldBeUpdated.insert({Dep.MDepCommand, true});
675+
return true;
676+
}
677+
ShouldBeUpdated[Dep.MDepCommand] = false;
678+
return false;
679+
});
680+
Cmd->MDeps.erase(NewEnd, Cmd->MDeps.end());
681+
682+
// Update users of removed dependencies
683+
for (auto DepCmdIt : ShouldBeUpdated) {
684+
if (!DepCmdIt.second)
685+
continue;
686+
std::vector<Command *> &DepUsers = DepCmdIt.first->MUsers;
687+
DepUsers.erase(std::remove(DepUsers.begin(), DepUsers.end(), Cmd),
688+
DepUsers.end());
672689
}
690+
691+
// If all dependencies have been removed this way, mark the command for
692+
// deletion
693+
if (Cmd->MDeps.empty()) {
694+
CmdsToDelete.push_back(Cmd);
695+
Cmd->MUsers.clear();
696+
}
697+
}
698+
699+
for (Command *Cmd : CmdsToDelete) {
700+
Cmd->getEvent()->setCommand(nullptr);
701+
delete Cmd;
673702
}
674703
}
675704

sycl/test/scheduler/FakeCommand.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include <CL/sycl.hpp>
2+
3+
// A fake command class used for testing
4+
class FakeCommand : public cl::sycl::detail::Command {
5+
public:
6+
FakeCommand(cl::sycl::detail::QueueImplPtr Queue,
7+
cl::sycl::detail::Requirement Req)
8+
: Command{cl::sycl::detail::Command::ALLOCA, Queue},
9+
MRequirement{std::move(Req)} {}
10+
11+
void printDot(std::ostream &Stream) const override {}
12+
13+
const cl::sycl::detail::Requirement *getRequirement() const final {
14+
return &MRequirement;
15+
};
16+
17+
cl_int enqueueImp() override { return MRetVal; }
18+
19+
cl_int MRetVal = CL_SUCCESS;
20+
21+
protected:
22+
cl::sycl::detail::Requirement MRequirement;
23+
};

sycl/test/scheduler/LeafLimit.cpp

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,12 @@
66
#include <memory>
77
#include <vector>
88

9+
#include "FakeCommand.hpp"
10+
911
// This test checks the leaf limit imposed on the execution graph
1012

1113
using namespace cl::sycl;
1214

13-
class FakeCommand : public detail::Command {
14-
public:
15-
FakeCommand(detail::QueueImplPtr Queue, detail::Requirement Req)
16-
: Command{detail::Command::ALLOCA, Queue}, MRequirement{std::move(Req)} {}
17-
18-
void printDot(std::ostream &Stream) const override {}
19-
20-
const detail::Requirement *getRequirement() const final {
21-
return &MRequirement;
22-
};
23-
24-
cl_int enqueueImp() override { return MRetVal; }
25-
26-
cl_int MRetVal = CL_SUCCESS;
27-
28-
protected:
29-
detail::Requirement MRequirement;
30-
};
31-
3215
class TestScheduler : public detail::Scheduler {
3316
public:
3417
void AddNodeToLeaves(detail::MemObjRecord *Rec, detail::Command *Cmd,

[New test file, located alongside FakeCommand.hpp; its file name is missing from this capture]

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
#include <CL/sycl.hpp>
4+
5+
#include <functional>
6+
#include <memory>
7+
#include <utility>
8+
9+
#include "FakeCommand.hpp"
10+
11+
// This test checks that the execution graph cleanup on memory object
12+
// destruction traverses the entire graph, rather than only the immediate users
13+
// of deleted commands.
14+
15+
using namespace cl::sycl;
16+
17+
class TestScheduler : public detail::Scheduler {
18+
public:
19+
void cleanupCommandsForRecord(detail::MemObjRecord *Rec) {
20+
MGraphBuilder.cleanupCommandsForRecord(Rec);
21+
}
22+
23+
void removeRecordForMemObj(detail::SYCLMemObjI *MemObj) {
24+
MGraphBuilder.removeRecordForMemObj(MemObj);
25+
}
26+
27+
detail::MemObjRecord *
28+
getOrInsertMemObjRecord(const detail::QueueImplPtr &Queue,
29+
detail::Requirement *Req) {
30+
return MGraphBuilder.getOrInsertMemObjRecord(Queue, Req);
31+
}
32+
};
33+
34+
class FakeCommandWithCallback : public FakeCommand {
35+
public:
36+
FakeCommandWithCallback(detail::QueueImplPtr Queue, detail::Requirement Req,
37+
std::function<void()> Callback)
38+
: FakeCommand(Queue, Req), MCallback(std::move(Callback)) {}
39+
40+
~FakeCommandWithCallback() override { MCallback(); }
41+
42+
protected:
43+
std::function<void()> MCallback;
44+
};
45+
46+
template <typename MemObjT>
47+
detail::Requirement getFakeRequirement(const MemObjT &MemObj) {
48+
return {{0, 0, 0},
49+
{0, 0, 0},
50+
{0, 0, 0},
51+
access::mode::read_write,
52+
detail::getSyclObjImpl(MemObj).get(),
53+
0,
54+
0,
55+
0};
56+
}
57+
58+
void addEdge(detail::Command *User, detail::Command *Dep,
59+
detail::AllocaCommandBase *Alloca) {
60+
User->addDep(detail::DepDesc{Dep, User->getRequirement(), Alloca});
61+
Dep->addUser(User);
62+
}
63+
64+
int main() {
65+
TestScheduler TS;
66+
queue Queue;
67+
buffer<int, 1> BufA(range<1>(1));
68+
buffer<int, 1> BufB(range<1>(1));
69+
detail::Requirement FakeReqA = getFakeRequirement(BufA);
70+
detail::Requirement FakeReqB = getFakeRequirement(BufB);
71+
detail::MemObjRecord *RecA =
72+
TS.getOrInsertMemObjRecord(detail::getSyclObjImpl(Queue), &FakeReqA);
73+
74+
// Create 2 fake allocas, one of which will be cleaned up
75+
detail::AllocaCommand *FakeAllocaA =
76+
new detail::AllocaCommand(detail::getSyclObjImpl(Queue), FakeReqA);
77+
std::unique_ptr<detail::AllocaCommand> FakeAllocaB{
78+
new detail::AllocaCommand(detail::getSyclObjImpl(Queue), FakeReqB)};
79+
RecA->MAllocaCommands.push_back(FakeAllocaA);
80+
81+
// Create a direct user of both allocas
82+
std::unique_ptr<FakeCommand> FakeDirectUser{
83+
new FakeCommand(detail::getSyclObjImpl(Queue), FakeReqA)};
84+
addEdge(FakeDirectUser.get(), FakeAllocaA, FakeAllocaA);
85+
addEdge(FakeDirectUser.get(), FakeAllocaB.get(), FakeAllocaB.get());
86+
87+
// Create an indirect user of the soon-to-be deleted alloca
88+
bool IndirectUserDeleted = false;
89+
std::function<void()> Callback = [&]() { IndirectUserDeleted = true; };
90+
FakeCommand *FakeIndirectUser = new FakeCommandWithCallback(
91+
detail::getSyclObjImpl(Queue), FakeReqA, Callback);
92+
addEdge(FakeIndirectUser, FakeDirectUser.get(), FakeAllocaA);
93+
94+
TS.cleanupCommandsForRecord(RecA);
95+
TS.removeRecordForMemObj(detail::getSyclObjImpl(BufA).get());
96+
97+
// Check that the direct user has been left with the second alloca
98+
// as the only dependency, while the indirect user has been cleaned up.
99+
assert(FakeDirectUser->MUsers.size() == 0);
100+
assert(FakeDirectUser->MDeps.size() == 1);
101+
assert(FakeDirectUser->MDeps[0].MDepCommand == FakeAllocaB.get());
102+
assert(IndirectUserDeleted);
103+
}

0 commit comments

Comments
 (0)