Skip to content

[SYCL] Exit early while trying to enqueue blocked tasks #2347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
6 changes: 5 additions & 1 deletion sycl/source/detail/scheduler/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,18 @@ class Command {
/// \param Blocking if this argument is true, function will wait for the
/// command to be unblocked before calling enqueueImp.
/// \return true if the command is enqueued.
bool enqueue(EnqueueResultT &EnqueueResult, BlockingT Blocking);
virtual bool enqueue(EnqueueResultT &EnqueueResult, BlockingT Blocking);

bool isFinished();

bool isSuccessfullyEnqueued() const {
return MEnqueueStatus == EnqueueResultT::SyclEnqueueSuccess;
}

bool isEnqueueBlocked() const {
return MEnqueueStatus == EnqueueResultT::SyclEnqueueBlocked;
}

std::shared_ptr<queue_impl> getQueue() const { return MQueue; }

std::shared_ptr<event_impl> getEvent() const { return MEvent; }
Expand Down
32 changes: 9 additions & 23 deletions sycl/source/detail/scheduler/graph_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,33 +58,19 @@ bool Scheduler::GraphProcessor::enqueueCommand(Command *Cmd,
if (!Cmd || Cmd->isSuccessfullyEnqueued())
return true;

// Indicates whether dependency cannot be enqueued
bool BlockedByDep = false;
// Exit early if the command is blocked and the enqueue type is non-blocking
if (Cmd->isEnqueueBlocked() && !Blocking) {
EnqueueResult = EnqueueResultT(EnqueueResultT::SyclEnqueueBlocked, Cmd);
return false;
}

// Recursively enqueue all the dependencies first and
// exit immediately if any of the commands cannot be enqueued.
for (DepDesc &Dep : Cmd->MDeps) {
const bool Enqueued =
enqueueCommand(Dep.MDepCommand, EnqueueResult, Blocking);
if (!Enqueued)
switch (EnqueueResult.MResult) {
case EnqueueResultT::SyclEnqueueFailed:
default:
// Exit immediately if a command fails to avoid enqueueing commands
// result of which will be discarded.
return false;
case EnqueueResultT::SyclEnqueueBlocked:
// If some dependency is blocked from enqueueing remember that, but
// try to enqueue other dependencies(that can be ready for
// enqueueing).
BlockedByDep = true;
break;
}
if (!enqueueCommand(Dep.MDepCommand, EnqueueResult, Blocking))
return false;
}

// Exit if some command is blocked from enqueueing, the EnqueueResult is set
// by the latest dependency which was blocked.
if (BlockedByDep)
return false;

return Cmd->enqueue(EnqueueResult, Blocking);
}

Expand Down
85 changes: 85 additions & 0 deletions sycl/unittests/scheduler/BlockedCommands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "SchedulerTestUtils.hpp"

using namespace cl::sycl;
using namespace testing;

TEST_F(SchedulerTest, BlockedCommands) {
MockCommand MockCmd(detail::getSyclObjImpl(MQueue));
Expand Down Expand Up @@ -45,3 +46,87 @@ TEST_F(SchedulerTest, BlockedCommands) {
Res.MResult == detail::EnqueueResultT::SyclEnqueueSuccess)
<< "The command is expected to be successfully enqueued.\n";
}

TEST_F(SchedulerTest, DontEnqueueDepsIfOneOfThemIsBlocked) {
MockCommand A(detail::getSyclObjImpl(MQueue));
A.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
A.MIsBlockable = true;
A.MRetVal = CL_SUCCESS;

MockCommand B(detail::getSyclObjImpl(MQueue));
B.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
B.MIsBlockable = true;
B.MRetVal = CL_SUCCESS;

MockCommand C(detail::getSyclObjImpl(MQueue));
C.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked;
C.MIsBlockable = true;

MockCommand D(detail::getSyclObjImpl(MQueue));
D.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
D.MIsBlockable = true;
D.MRetVal = CL_SUCCESS;

addEdge(&A, &B, nullptr);
addEdge(&A, &C, nullptr);
addEdge(&A, &D, nullptr);

// We have such a graph:
//
// A
// / | \
// B C D
//
// If C is blocked, we should not try to enqueue D.
Comment on lines +74 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand it correctly that A depends on B, C and D? Why can't we enqueue D then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed it with @romanovvlad and AFAIU, in the current design there are no benefits of including D, as it is likely to be blocked as well.


EXPECT_CALL(A, enqueue(_, _)).Times(0);
EXPECT_CALL(B, enqueue(_, _)).Times(1);
EXPECT_CALL(C, enqueue(_, _)).Times(0);
EXPECT_CALL(D, enqueue(_, _)).Times(0);

detail::EnqueueResultT Res;
bool Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::NON_BLOCKING);
ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n";
ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueBlocked, Res.MResult)
<< "Result of enqueueing blocked command should be BLOCKED.\n";
ASSERT_EQ(&C, Res.MCmd) << "Expected different failed command.\n";
}

TEST_F(SchedulerTest, EnqueueBlockedCommandEarlyExit) {
MockCommand A(detail::getSyclObjImpl(MQueue));
A.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked;
A.MIsBlockable = true;

MockCommand B(detail::getSyclObjImpl(MQueue));
B.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
B.MRetVal = CL_OUT_OF_RESOURCES;

addEdge(&A, &B, nullptr);

// We have such a graph:
//
// A -> B
//
// If A is blocked, we should not try to enqueue B.

EXPECT_CALL(A, enqueue(_, _)).Times(0);
EXPECT_CALL(B, enqueue(_, _)).Times(0);

detail::EnqueueResultT Res;
bool Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::NON_BLOCKING);
ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n";
ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueBlocked, Res.MResult)
<< "Result of enqueueing blocked command should be BLOCKED.\n";
ASSERT_EQ(&A, Res.MCmd) << "Expected different failed command.\n";

// But if the enqueue type is blocking we should not exit early.

EXPECT_CALL(A, enqueue(_, _)).Times(0);
EXPECT_CALL(B, enqueue(_, _)).Times(1);

Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::BLOCKING);
ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n";
ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueFailed, Res.MResult)
<< "Result of enqueueing blocked command should be BLOCKED.\n";
ASSERT_EQ(&B, Res.MCmd) << "Expected different failed command.\n";
}
31 changes: 17 additions & 14 deletions sycl/unittests/scheduler/LeafLimit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,45 @@ using namespace cl::sycl;
// overflowed.
TEST_F(SchedulerTest, LeafLimit) {
MockScheduler MS;
std::vector<std::unique_ptr<MockCommand>> LeavesToAdd;
std::unique_ptr<MockCommand> MockDepCmd;

buffer<int, 1> Buf(range<1>(1));
detail::Requirement MockReq = getMockRequirement(Buf);
MockCommand *MockDepCmd =
new MockCommand(detail::getSyclObjImpl(MQueue), MockReq);

MockDepCmd =
std::make_unique<MockCommand>(detail::getSyclObjImpl(MQueue), MockReq);
detail::MemObjRecord *Rec =
MS.getOrInsertMemObjRecord(detail::getSyclObjImpl(MQueue), &MockReq);

// Create commands that will be added as leaves exceeding the limit by 1
std::vector<MockCommand *> LeavesToAdd;
for (std::size_t i = 0; i < Rec->MWriteLeaves.genericCommandsCapacity() + 1;
++i) {
LeavesToAdd.push_back(
new MockCommand(detail::getSyclObjImpl(MQueue), MockReq));
std::make_unique<MockCommand>(detail::getSyclObjImpl(MQueue), MockReq));
}
// Create edges: all soon-to-be leaves are direct users of MockDep
for (auto Leaf : LeavesToAdd) {
MockDepCmd->addUser(Leaf);
Leaf->addDep(detail::DepDesc{MockDepCmd, Leaf->getRequirement(), nullptr});
for (auto &Leaf : LeavesToAdd) {
MockDepCmd->addUser(Leaf.get());
Leaf->addDep(
detail::DepDesc{MockDepCmd.get(), Leaf->getRequirement(), nullptr});
}
// Add edges as leaves and exceed the leaf limit
for (auto LeafPtr : LeavesToAdd) {
MS.addNodeToLeaves(Rec, LeafPtr);
for (auto &LeafPtr : LeavesToAdd) {
MS.addNodeToLeaves(Rec, LeafPtr.get());
}
// Check that the oldest leaf has been removed from the leaf list
// and added as a dependency of the newest one instead
const detail::CircularBuffer<detail::Command *> &Leaves =
Rec->MWriteLeaves.getGenericCommands();
ASSERT_TRUE(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd.front()) ==
Leaves.end());
ASSERT_TRUE(std::find(Leaves.begin(), Leaves.end(),
LeavesToAdd.front().get()) == Leaves.end());
for (std::size_t i = 1; i < LeavesToAdd.size(); ++i) {
assert(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd[i]) !=
assert(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd[i].get()) !=
Leaves.end());
}
MockCommand *OldestLeaf = LeavesToAdd.front();
MockCommand *NewestLeaf = LeavesToAdd.back();
MockCommand *OldestLeaf = LeavesToAdd.front().get();
MockCommand *NewestLeaf = LeavesToAdd.back().get();
ASSERT_EQ(OldestLeaf->MUsers.size(), 1U);
EXPECT_GT(OldestLeaf->MUsers.count(NewestLeaf), 0U);
ASSERT_EQ(NewestLeaf->MDeps.size(), 2U);
Expand Down
23 changes: 21 additions & 2 deletions sycl/unittests/scheduler/SchedulerTestUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <detail/scheduler/scheduler.hpp>

#include <functional>
#include <gmock/gmock.h>

// This header contains a few common classes/methods used in
// execution graph testing.

Expand All @@ -24,12 +26,22 @@ class MockCommand : public cl::sycl::detail::Command {
cl::sycl::detail::Requirement Req,
cl::sycl::detail::Command::CommandType Type =
cl::sycl::detail::Command::RUN_CG)
: Command{Type, Queue}, MRequirement{std::move(Req)} {}
: Command{Type, Queue}, MRequirement{std::move(Req)} {
using namespace testing;
ON_CALL(*this, enqueue(_, _))
.WillByDefault(Invoke(this, &MockCommand::enqueueOrigin));
EXPECT_CALL(*this, enqueue(_, _)).Times(AnyNumber());
}

MockCommand(cl::sycl::detail::QueueImplPtr Queue,
cl::sycl::detail::Command::CommandType Type =
cl::sycl::detail::Command::RUN_CG)
: Command{Type, Queue}, MRequirement{std::move(getMockRequirement())} {}
: Command{Type, Queue}, MRequirement{std::move(getMockRequirement())} {
using namespace testing;
ON_CALL(*this, enqueue(_, _))
.WillByDefault(Invoke(this, &MockCommand::enqueueOrigin));
EXPECT_CALL(*this, enqueue(_, _)).Times(AnyNumber());
}

void printDot(std::ostream &) const override {}
void emitInstrumentationData() override {}
Expand All @@ -40,6 +52,13 @@ class MockCommand : public cl::sycl::detail::Command {

cl_int enqueueImp() override { return MRetVal; }

MOCK_METHOD2(enqueue, bool(cl::sycl::detail::EnqueueResultT &,
cl::sycl::detail::BlockingT));
bool enqueueOrigin(cl::sycl::detail::EnqueueResultT &EnqueueResult,
cl::sycl::detail::BlockingT Blocking) {
return cl::sycl::detail::Command::enqueue(EnqueueResult, Blocking);
}

cl_int MRetVal = CL_SUCCESS;

void waitForEventsCall(
Expand Down