Skip to content

Commit 0cd42e0

Browse files
committed
Address comment
1 parent 7adb58f commit 0cd42e0

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

src/libtorch.cc

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -620,25 +620,17 @@ ModelInstanceState::ModelInstanceState(
620620
#ifdef TRITON_ENABLE_GPU
621621
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
622622
// Only set the torch device and create a CUDA stream if the model uses GPU.
623+
for (auto it = device_id_set_.begin(); it != device_id_set_.end(); ++it) {
624+
cudaStream_t stream;
625+
THROW_IF_BACKEND_INSTANCE_ERROR(
626+
CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
627+
stream_map_.insert({*it, stream});
628+
}
623629
if (!device_id_set_.empty()) {
624-
auto it = device_id_set_.begin();
625630
// Use the first device to create the default stream.
626-
THROW_IF_BACKEND_INSTANCE_ERROR(
627-
CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream_));
631+
auto it = device_id_set_.begin();
628632
device_ = torch::Device(torch::kCUDA, *it);
629-
630-
// Create a CUDA stream for other devices so that they can be synchronized
631-
// later. Skip the first device since it is used to create the default
632-
// stream.
633-
if (it != device_id_set_.end()) {
634-
++it;
635-
}
636-
for (; it != device_id_set_.end(); ++it) {
637-
cudaStream_t stream;
638-
THROW_IF_BACKEND_INSTANCE_ERROR(
639-
CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
640-
stream_map_.insert({*it, stream});
641-
}
633+
stream_ = stream_map_[*it];
642634
}
643635
}
644636

0 commit comments

Comments
 (0)