@@ -391,6 +391,8 @@ struct _pi_queue {
391
391
std::atomic_uint32_t transfer_stream_idx_;
392
392
unsigned int num_compute_streams_;
393
393
unsigned int num_transfer_streams_;
394
+ unsigned int last_sync_compute_streams_;
395
+ unsigned int last_sync_transfer_streams_;
394
396
unsigned int flags_;
395
397
std::mutex compute_stream_mutex_;
396
398
std::mutex transfer_stream_mutex_;
@@ -403,7 +405,9 @@ struct _pi_queue {
403
405
transfer_streams_{std::move (transfer_streams)}, context_{context},
404
406
device_{device}, properties_{properties}, refCount_{1 }, eventCount_{0 },
405
407
compute_stream_idx_{0 }, transfer_stream_idx_{0 },
406
- num_compute_streams_{0 }, num_transfer_streams_{0 }, flags_(flags) {
408
+ num_compute_streams_{0 }, num_transfer_streams_{0 },
409
+ last_sync_compute_streams_{0 }, last_sync_transfer_streams_{0 },
410
+ flags_ (flags) {
407
411
cuda_piContextRetain (context_);
408
412
cuda_piDeviceRetain (device_);
409
413
}
@@ -440,6 +444,59 @@ struct _pi_queue {
440
444
}
441
445
}
442
446
447
+ template <typename T> void sync_streams (T &&f) {
448
+ auto sync = [&f](const std::vector<CUstream> &streams, unsigned int start,
449
+ unsigned int stop) {
450
+ for (unsigned int i = start; i < stop; i++) {
451
+ f (streams[i]);
452
+ }
453
+ };
454
+ {
455
+ unsigned int size = static_cast <unsigned int >(compute_streams_.size ());
456
+ std::lock_guard<std::mutex> compute_guard (compute_stream_mutex_);
457
+ unsigned int start = last_sync_compute_streams_;
458
+ unsigned int end = num_compute_streams_ < size
459
+ ? num_compute_streams_
460
+ : compute_stream_idx_.load ();
461
+ last_sync_compute_streams_ = end;
462
+ if (end - start >= size) {
463
+ sync (compute_streams_, 0 , size);
464
+ } else {
465
+ start %= size;
466
+ end %= size;
467
+ if (start < end) {
468
+ sync (compute_streams_, start, end);
469
+ } else {
470
+ sync (compute_streams_, start, size);
471
+ sync (compute_streams_, 0 , end);
472
+ }
473
+ }
474
+ }
475
+ {
476
+ unsigned int size = static_cast <unsigned int >(transfer_streams_.size ());
477
+ if (size > 0 ) {
478
+ std::lock_guard<std::mutex> transfer_guard (transfer_stream_mutex_);
479
+ unsigned int start = last_sync_transfer_streams_;
480
+ unsigned int end = num_transfer_streams_ < size
481
+ ? num_transfer_streams_
482
+ : transfer_stream_idx_.load ();
483
+ last_sync_transfer_streams_ = end;
484
+ if (end - start >= size) {
485
+ sync (transfer_streams_, 0 , size);
486
+ } else {
487
+ start %= size;
488
+ end %= size;
489
+ if (start < end) {
490
+ sync (transfer_streams_, start, end);
491
+ } else {
492
+ sync (transfer_streams_, start, size);
493
+ sync (transfer_streams_, 0 , end);
494
+ }
495
+ }
496
+ }
497
+ }
498
+ }
499
+
443
500
// Returns the context pointer stored in this queue (context_). The pointer is
// handed back as-is; no reference count is modified by this accessor.
_pi_context *get_context() const {
  return context_;
}
444
501
445
502
// Returns the device pointer stored in this queue (device_). The pointer is
// handed back as-is; no reference count is modified by this accessor.
_pi_device *get_device() const {
  return device_;
}
0 commit comments