Skip to content

Commit 8f6a793

Browse files
committed
Address comment
1 parent 5a79456 commit 8f6a793

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

src/libtorch.cc

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -628,25 +628,17 @@ ModelInstanceState::ModelInstanceState(
628628
#ifdef TRITON_ENABLE_GPU
629629
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
630630
// Only set the torch device and create a CUDA stream if the model uses GPU.
631+
for (auto it = device_id_set_.begin(); it != device_id_set_.end(); ++it) {
632+
cudaStream_t stream;
633+
THROW_IF_BACKEND_INSTANCE_ERROR(
634+
CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
635+
stream_map_.insert({*it, stream});
636+
}
631637
if (!device_id_set_.empty()) {
632-
auto it = device_id_set_.begin();
633638
// Use the first device to create the default stream.
634-
THROW_IF_BACKEND_INSTANCE_ERROR(
635-
CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream_));
639+
auto it = device_id_set_.begin();
636640
device_ = torch::Device(torch::kCUDA, *it);
637-
638-
// Create a CUDA stream for other devices so that they can be synchronized
639-
// later. Skip the first device since it is used to create the default
640-
// stream.
641-
if (it != device_id_set_.end()) {
642-
++it;
643-
}
644-
for (; it != device_id_set_.end(); ++it) {
645-
cudaStream_t stream;
646-
THROW_IF_BACKEND_INSTANCE_ERROR(
647-
CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
648-
stream_map_.insert({*it, stream});
649-
}
641+
stream_ = stream_map_[*it];
650642
}
651643
}
652644

0 commit comments

Comments
 (0)