Skip to content

Commit dc3b92f

Browse files
committed
Address comment
1 parent a9af556 commit dc3b92f

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

src/libtorch.cc

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -627,25 +627,17 @@ ModelInstanceState::ModelInstanceState(
627627
#ifdef TRITON_ENABLE_GPU
628628
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
629629
// Only set the torch device and create a CUDA stream if the model uses GPU.
630+
for (auto it = device_id_set_.begin(); it != device_id_set_.end(); ++it) {
631+
cudaStream_t stream;
632+
THROW_IF_BACKEND_INSTANCE_ERROR(
633+
CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
634+
stream_map_.insert({*it, stream});
635+
}
630636
if (!device_id_set_.empty()) {
631-
auto it = device_id_set_.begin();
632637
// Use the first device to create the default stream.
633-
THROW_IF_BACKEND_INSTANCE_ERROR(
634-
CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream_));
638+
auto it = device_id_set_.begin();
635639
device_ = torch::Device(torch::kCUDA, *it);
636-
637-
// Create a CUDA stream for other devices so that they can be synchronized
638-
// later. Skip the first device since it is used to create the default
639-
// stream.
640-
if (it != device_id_set_.end()) {
641-
++it;
642-
}
643-
for (; it != device_id_set_.end(); ++it) {
644-
cudaStream_t stream;
645-
THROW_IF_BACKEND_INSTANCE_ERROR(
646-
CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
647-
stream_map_.insert({*it, stream});
648-
}
640+
stream_ = stream_map_[*it];
649641
}
650642
}
651643

0 commit comments

Comments
 (0)