Skip to content

Commit ee4b1f7

Browse files
romanovvladbader
authored andcommitted
[SYCL] Refactor the way commands store requirements (#839)
1. Store requirements in all the commands. This prevents potential use after free. 2. Unified way to access stored requirement. For most of the functions that are only one requirement that needs to be stored and accessed. Being able to get it using just parent type(Command) simplifies code which sets dependency 3. Align/Fix variables names across commands Signed-off-by: Vlad Romanov <[email protected]>
1 parent 2aa2568 commit ee4b1f7

File tree

3 files changed

+207
-178
lines changed

3 files changed

+207
-178
lines changed

sycl/include/CL/sycl/detail/scheduler/commands.hpp

Lines changed: 90 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,21 @@ struct EnqueueResultT {
4848
cl_int MErrCode;
4949
};
5050

51-
5251
// DepDesc represents dependency between two commands
5352
struct DepDesc {
54-
DepDesc(Command *DepCommand, Requirement *Req, AllocaCommandBase *AllocaCmd)
55-
: MDepCommand(DepCommand), MReq(Req), MAllocaCmd(AllocaCmd) {}
53+
DepDesc(Command *DepCommand, const Requirement *Req,
54+
AllocaCommandBase *AllocaCmd)
55+
: MDepCommand(DepCommand), MDepRequirement(Req), MAllocaCmd(AllocaCmd) {}
5656

5757
friend bool operator<(const DepDesc &Lhs, const DepDesc &Rhs) {
58-
return std::tie(Lhs.MReq, Lhs.MDepCommand) <
59-
std::tie(Rhs.MReq, Rhs.MDepCommand);
58+
return std::tie(Lhs.MDepRequirement, Lhs.MDepCommand) <
59+
std::tie(Rhs.MDepRequirement, Rhs.MDepCommand);
6060
}
6161

6262
// The actual dependency command.
6363
Command *MDepCommand = nullptr;
6464
// Requirement for the dependency.
65-
Requirement *MReq = nullptr;
65+
const Requirement *MDepRequirement = nullptr;
6666
// Allocation command for the memory object we have requirement for.
6767
// Used to simplify searching for memory handle.
6868
AllocaCommandBase *MAllocaCmd = nullptr;
@@ -116,10 +116,15 @@ class Command {
116116

117117
std::shared_ptr<event_impl> getEvent() const { return MEvent; }
118118

119-
virtual ~Command() = default;
120-
121119
virtual void printDot(std::ostream &Stream) const = 0;
122120

121+
virtual const Requirement *getRequirement() const {
122+
assert(!"Internal Error. The command has no stored requirement");
123+
return nullptr;
124+
}
125+
126+
virtual ~Command() = default;
127+
123128
protected:
124129
EventImplPtr MEvent;
125130
QueueImplPtr MQueue;
@@ -129,22 +134,23 @@ class Command {
129134
RT::PiEvent &Event);
130135
std::vector<RT::PiEvent> prepareEvents(ContextImplPtr Context);
131136

132-
bool MUseExclusiveQueue = false;
133-
134137
// Private interface. Derived classes should implement this method.
135138
virtual cl_int enqueueImp() = 0;
136139

137-
public:
140+
bool MUseExclusiveQueue = false;
141+
138142
// The type of the command
139143
CommandType MType;
140144
// Indicates whether the command is enqueued or not
141145
std::atomic<bool> MEnqueued;
146+
// Mutex used to protect enqueueing from race conditions
147+
std::mutex MEnqueueMtx;
148+
149+
public:
142150
// Contains list of dependencies(edges)
143151
std::vector<DepDesc> MDeps;
144152
// Contains list of commands that depend on the command
145153
std::vector<Command *> MUsers;
146-
// Mutex used to protect enqueueing from race conditions
147-
std::mutex MEnqueueMtx;
148154
// Indicates whether the command can be blocked from enqueueing
149155
bool MIsBlockable = false;
150156
// Indicates whether the command is blocked from enqueueing
@@ -155,17 +161,17 @@ class Command {
155161
// lock in the graph, or to merge several nodes into one.
156162
class EmptyCommand : public Command {
157163
public:
158-
EmptyCommand(QueueImplPtr Queue, Requirement *Req)
164+
EmptyCommand(QueueImplPtr Queue, Requirement Req)
159165
: Command(CommandType::EMPTY_TASK, std::move(Queue)),
160-
MStoredRequirement(*Req) {}
166+
MRequirement(std::move(Req)) {}
161167

162-
Requirement *getStoredRequirement() { return &MStoredRequirement; }
168+
void printDot(std::ostream &Stream) const final;
169+
const Requirement *getRequirement() const final { return &MRequirement; }
163170

164171
private:
165-
cl_int enqueueImp() override { return CL_SUCCESS; }
166-
void printDot(std::ostream &Stream) const override;
172+
cl_int enqueueImp() final { return CL_SUCCESS; }
167173

168-
Requirement MStoredRequirement;
174+
Requirement MRequirement;
169175
};
170176

171177
// The command enqueues release instance of memory allocated on Host or
@@ -176,32 +182,34 @@ class ReleaseCommand : public Command {
176182
: Command(CommandType::RELEASE, std::move(Queue)), MAllocaCmd(AllocaCmd) {
177183
}
178184

179-
void printDot(std::ostream &Stream) const override;
185+
void printDot(std::ostream &Stream) const final;
180186

181187
private:
182-
cl_int enqueueImp() override;
188+
cl_int enqueueImp() final;
183189

190+
// Command which allocates memory release command should dealocate
184191
AllocaCommandBase *MAllocaCmd = nullptr;
185192
};
186193

187194
class AllocaCommandBase : public Command {
188195
public:
189196
AllocaCommandBase(CommandType Type, QueueImplPtr Queue, Requirement Req)
190-
: Command(Type, Queue), MReleaseCmd(Queue, this), MReq(std::move(Req)) {
191-
MReq.MAccessMode = access::mode::read_write;
197+
: Command(Type, Queue), MReleaseCmd(Queue, this),
198+
MRequirement(std::move(Req)) {
199+
MRequirement.MAccessMode = access::mode::read_write;
192200
}
193201

194202
ReleaseCommand *getReleaseCmd() { return &MReleaseCmd; }
195203

196-
SYCLMemObjI *getSYCLMemObj() const { return MReq.MSYCLMemObj; }
204+
SYCLMemObjI *getSYCLMemObj() const { return MRequirement.MSYCLMemObj; }
197205

198206
void *getMemAllocation() const { return MMemAllocation; }
199207

200-
Requirement *getAllocationReq() { return &MReq; }
208+
const Requirement *getRequirement() const final { return &MRequirement; }
201209

202210
protected:
203211
ReleaseCommand MReleaseCmd;
204-
Requirement MReq;
212+
Requirement MRequirement;
205213
void *MMemAllocation = nullptr;
206214
};
207215

@@ -211,16 +219,19 @@ class AllocaCommand : public AllocaCommandBase {
211219
public:
212220
AllocaCommand(QueueImplPtr Queue, Requirement Req,
213221
bool InitFromUserData = true)
214-
: AllocaCommandBase(CommandType::ALLOCA, std::move(Queue), Req),
222+
: AllocaCommandBase(CommandType::ALLOCA, std::move(Queue),
223+
std::move(Req)),
215224
MInitFromUserData(InitFromUserData) {
216-
addDep(DepDesc(nullptr, &MReq, this));
225+
addDep(DepDesc(nullptr, getRequirement(), this));
217226
}
218227

219-
void printDot(std::ostream &Stream) const override;
228+
void printDot(std::ostream &Stream) const final;
220229

221230
private:
222-
cl_int enqueueImp() override final;
231+
cl_int enqueueImp() final;
223232

233+
// The flag indicates that alloca should try to reuse pointer provided by the
234+
// user during memory object construction
224235
bool MInitFromUserData = false;
225236
};
226237

@@ -231,90 +242,95 @@ class AllocaSubBufCommand : public AllocaCommandBase {
231242
: AllocaCommandBase(CommandType::ALLOCA_SUB_BUF, std::move(Queue),
232243
std::move(Req)),
233244
MParentAlloca(ParentAlloca) {
234-
addDep(DepDesc(MParentAlloca, &MReq, MParentAlloca));
245+
addDep(DepDesc(MParentAlloca, getRequirement(), MParentAlloca));
235246
}
236247

237-
void printDot(std::ostream &Stream) const override;
248+
void printDot(std::ostream &Stream) const final;
238249
AllocaCommandBase *getParentAlloca() { return MParentAlloca; }
239250

240251
private:
241-
cl_int enqueueImp() override final;
252+
cl_int enqueueImp() final;
242253

243254
AllocaCommandBase *MParentAlloca;
244255
};
245256

246257
class MapMemObject : public Command {
247258
public:
248-
MapMemObject(AllocaCommandBase *SrcAlloca, Requirement *Req, void **DstPtr,
259+
MapMemObject(AllocaCommandBase *SrcAllocaCmd, Requirement Req, void **DstPtr,
249260
QueueImplPtr Queue);
250261

251-
AllocaCommandBase *MSrcAlloca = nullptr;
252-
void **MDstPtr = nullptr;
253-
Requirement MReq;
254-
255-
void printDot(std::ostream &Stream) const override;
262+
void printDot(std::ostream &Stream) const final;
263+
const Requirement *getRequirement() const final { return &MSrcReq; }
256264

257265
private:
258-
cl_int enqueueImp() override;
266+
cl_int enqueueImp() final;
267+
268+
AllocaCommandBase *MSrcAllocaCmd = nullptr;
269+
Requirement MSrcReq;
270+
void **MDstPtr = nullptr;
259271
};
260272

261273
class UnMapMemObject : public Command {
262274
public:
263-
UnMapMemObject(AllocaCommandBase *DstAlloca, Requirement *Req, void **SrcPtr,
264-
QueueImplPtr Queue, bool UseExclusiveQueue = false);
275+
UnMapMemObject(AllocaCommandBase *DstAllocaCmd, Requirement Req,
276+
void **SrcPtr, QueueImplPtr Queue,
277+
bool UseExclusiveQueue = false);
265278

266-
void printDot(std::ostream &Stream) const override;
279+
void printDot(std::ostream &Stream) const final;
280+
const Requirement *getRequirement() const final { return &MDstReq; }
267281

268282
private:
269-
cl_int enqueueImp() override;
283+
cl_int enqueueImp() final;
270284

271-
AllocaCommandBase *MDstAlloca = nullptr;
272-
Requirement MReq;
285+
AllocaCommandBase *MDstAllocaCmd = nullptr;
286+
Requirement MDstReq;
273287
void **MSrcPtr = nullptr;
274288
};
275289

276290
// The command enqueues memory copy between two instances of memory object.
277291
class MemCpyCommand : public Command {
278292
public:
279-
MemCpyCommand(Requirement SrcReq, AllocaCommandBase *SrcAlloca,
280-
Requirement DstReq, AllocaCommandBase *DstAlloca,
293+
MemCpyCommand(Requirement SrcReq, AllocaCommandBase *SrcAllocaCmd,
294+
Requirement DstReq, AllocaCommandBase *DstAllocaCmd,
281295
QueueImplPtr SrcQueue, QueueImplPtr DstQueue,
282296
bool UseExclusiveQueue = false);
283297

284-
QueueImplPtr MSrcQueue;
285-
Requirement MSrcReq;
286-
AllocaCommandBase *MSrcAlloca = nullptr;
287-
Requirement MDstReq;
288-
AllocaCommandBase *MDstAlloca = nullptr;
289-
Requirement *MAccToUpdate = nullptr;
290-
291298
void setAccessorToUpdate(Requirement *AccToUpdate) {
292299
MAccToUpdate = AccToUpdate;
293300
}
294301

295-
void printDot(std::ostream &Stream) const override;
302+
void printDot(std::ostream &Stream) const final;
303+
const Requirement *getRequirement() const final { return &MDstReq; }
296304

297305
private:
298-
cl_int enqueueImp() override;
306+
cl_int enqueueImp() final;
307+
308+
QueueImplPtr MSrcQueue;
309+
Requirement MSrcReq;
310+
AllocaCommandBase *MSrcAllocaCmd = nullptr;
311+
Requirement MDstReq;
312+
AllocaCommandBase *MDstAllocaCmd = nullptr;
313+
Requirement *MAccToUpdate = nullptr;
299314
};
300315

301316
// The command enqueues memory copy between two instances of memory object.
302317
class MemCpyCommandHost : public Command {
303318
public:
304-
MemCpyCommandHost(Requirement SrcReq, AllocaCommandBase *SrcAlloca,
319+
MemCpyCommandHost(Requirement SrcReq, AllocaCommandBase *SrcAllocaCmd,
305320
Requirement DstReq, void **DstPtr, QueueImplPtr SrcQueue,
306321
QueueImplPtr DstQueue);
307322

323+
void printDot(std::ostream &Stream) const final;
324+
const Requirement *getRequirement() const final { return &MDstReq; }
325+
326+
private:
327+
cl_int enqueueImp() final;
328+
308329
QueueImplPtr MSrcQueue;
309330
Requirement MSrcReq;
310-
AllocaCommandBase *MSrcAlloca = nullptr;
331+
AllocaCommandBase *MSrcAllocaCmd = nullptr;
311332
Requirement MDstReq;
312333
void **MDstPtr = nullptr;
313-
314-
void printDot(std::ostream &Stream) const override;
315-
316-
private:
317-
cl_int enqueueImp() override;
318334
};
319335

320336
// The command enqueues execution of kernel or explicit memory operation.
@@ -326,11 +342,10 @@ class ExecCGCommand : public Command {
326342

327343
void flushStreams();
328344

329-
void printDot(std::ostream &Stream) const override;
345+
void printDot(std::ostream &Stream) const final;
330346

331347
private:
332-
// Implementation of enqueueing of ExecCGCommand.
333-
cl_int enqueueImp() override;
348+
cl_int enqueueImp() final;
334349

335350
AllocaCommandBase *getAllocaForReq(Requirement *Req);
336351

@@ -339,20 +354,20 @@ class ExecCGCommand : public Command {
339354

340355
class UpdateHostRequirementCommand : public Command {
341356
public:
342-
UpdateHostRequirementCommand(QueueImplPtr Queue, AllocaCommandBase *AllocaCmd,
343-
Requirement *Req, void **DstPtr)
357+
UpdateHostRequirementCommand(QueueImplPtr Queue, Requirement Req,
358+
AllocaCommandBase *SrcAllocaCmd, void **DstPtr)
344359
: Command(CommandType::UPDATE_REQUIREMENT, std::move(Queue)),
345-
MDstPtr(DstPtr), MAllocaCmd(AllocaCmd), MReq(*Req) {}
360+
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(Req)), MDstPtr(DstPtr) {}
346361

347-
Requirement *getStoredRequirement() { return &MReq; }
362+
void printDot(std::ostream &Stream) const final;
363+
const Requirement *getRequirement() const final { return &MDstReq; }
348364

349365
private:
350-
cl_int enqueueImp() override;
351-
void printDot(std::ostream &Stream) const override;
366+
cl_int enqueueImp() final;
352367

368+
AllocaCommandBase *MSrcAllocaCmd = nullptr;
369+
Requirement MDstReq;
353370
void **MDstPtr = nullptr;
354-
AllocaCommandBase *MAllocaCmd = nullptr;
355-
Requirement MReq;
356371
};
357372

358373
} // namespace detail

0 commit comments

Comments
 (0)