
Commit 7adb58f

Handle multi GPU cases when recording timestamps
1 parent 883df35 · commit 7adb58f

File tree

1 file changed: +109 −16 lines

src/libtorch.cc

Lines changed: 109 additions & 16 deletions
@@ -79,6 +79,7 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
       std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+      std::unordered_set<int>& device_id_set,
       std::shared_ptr<torch::jit::script::Module>* torch_model);

   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -204,6 +205,7 @@ TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
     std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+    std::unordered_set<int>& device_id_set,
     std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
@@ -254,9 +256,23 @@ ModelState::LoadModel(
   try {
     std::istringstream model_stream(model_data_str);
     if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
-      // Don't select the device when loading the model.
       torch_model->reset(
          new torch::jit::Module(torch::jit::load(model_stream)));
+
+      // Get the device used in the model
+      auto parameters = (*torch_model)->parameters();
+      auto buffers = (*torch_model)->buffers();
+
+      for (const auto& parameter : parameters) {
+        if (parameter.device().type() != torch::kCPU) {
+          device_id_set.insert(parameter.device().index());
+        }
+      }
+      for (const auto& buffer : buffers) {
+        if (buffer.device().type() != torch::kCPU) {
+          device_id_set.insert(buffer.device().index());
+        }
+      }
     } else {
       torch_model->reset(
           new torch::jit::Module(torch::jit::load(model_stream, device)));
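The device-discovery step added above works because every parameter and buffer of a torch::jit::Module reports the device it lives on. The following is a minimal standalone sketch of the same idea, not backend code; it assumes a TorchScript artifact saved as "model.pt" (the path is illustrative) and a libtorch build.

// device_scan.cc -- illustrative sketch; build against libtorch (torch/script.h).
#include <torch/script.h>

#include <iostream>
#include <unordered_set>

int main()
{
  // "model.pt" is a placeholder path to a TorchScript artifact.
  torch::jit::Module module = torch::jit::load("model.pt");

  // Collect the CUDA device index of every parameter and buffer, skipping
  // anything that lives on the CPU, just as the LoadModel change does.
  std::unordered_set<int> device_ids;
  for (const auto& parameter : module.parameters()) {
    if (parameter.device().type() != torch::kCPU) {
      device_ids.insert(parameter.device().index());
    }
  }
  for (const auto& buffer : module.buffers()) {
    if (buffer.device().type() != torch::kCPU) {
      device_ids.insert(buffer.device().index());
    }
  }

  for (int id : device_ids) {
    std::cout << "model uses CUDA device " << id << std::endl;
  }
  return 0;
}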
@@ -559,6 +575,13 @@ class ModelInstanceState : public BackendModelInstance {
   cudaEvent_t compute_input_start_event_;
   cudaEvent_t compute_infer_start_event_;
   cudaEvent_t compute_output_start_event_;
+
+  // Store the GPU device IDs used in a model for the instance group of type
+  // 'MODEL'.
+  std::unordered_set<int> device_id_set_;
+  // Store the extra cuda streams created for the instance group of type
+  // 'MODEL', keyed by device ID.
+  std::unordered_map<int, cudaStream_t> stream_map_;
 };

 TRITONSERVER_Error*
@@ -587,10 +610,43 @@ ModelInstanceState::ModelInstanceState(
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
     device_ = torch::Device(torch::kCUDA, DeviceId());
+#endif
+  }
+
+  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
+      ArtifactFilename(), device_, &model_path_, Kind(), device_id_set_,
+      &torch_model_));
+
+#ifdef TRITON_ENABLE_GPU
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    // Only set the torch device and create a CUDA stream if the model uses GPU.
+    if (!device_id_set_.empty()) {
+      auto it = device_id_set_.begin();
+      // Use the first device to create the default stream.
+      THROW_IF_BACKEND_INSTANCE_ERROR(
+          CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream_));
+      device_ = torch::Device(torch::kCUDA, *it);
+
+      // Create a CUDA stream for other devices so that they can be synchronized
+      // later. Skip the first device since it is used to create the default
+      // stream.
+      if (it != device_id_set_.end()) {
+        ++it;
+      }
+      for (; it != device_id_set_.end(); ++it) {
+        cudaStream_t stream;
+        THROW_IF_BACKEND_INSTANCE_ERROR(
+            CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
+        stream_map_.insert({*it, stream});
+      }
+    }
+  }
+
+  if (device_.is_cuda()) {
     // Need to set the CUDA context so that the context that events are
     // created on match with contexts that events are recorded with.
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-        cudaSetDevice(DeviceId()), TRITONSERVER_ERROR_INTERNAL,
+        cudaSetDevice(device_.index()), TRITONSERVER_ERROR_INTERNAL,
         "Failed to set the device"));
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_input_start_event_),
@@ -601,11 +657,8 @@ ModelInstanceState::ModelInstanceState(
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_output_start_event_),
         TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
-#endif
   }
-
-  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
+#endif

   size_t expected_input_cnt = 0;
   {
@@ -667,6 +720,21 @@ ModelInstanceState::~ModelInstanceState()
 {
   torch_model_.reset();
   ClearCache();
+
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaSetDevice(m.first);
+      cudaError_t err = cudaStreamDestroy(m.second);
+      if (err != cudaSuccess) {
+        TRITONSERVER_LogMessage(
+            TRITONSERVER_LOG_ERROR, __FILE__, __LINE__,
+            (std::string("~ModelInstanceState: ") + name_ +
+             " failed to destroy cuda stream: " + cudaGetErrorString(err))
+                .c_str());
+      }
+      m.second = nullptr;
+    }
+  }
 }

 TRITONSERVER_Error*
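Taken together, the constructor and destructor changes give each GPU used by a 'MODEL' instance group its own stream: one stream per device ID is created up front, every stream is synchronized before results are consumed, and each stream is destroyed on the device that owns it. Below is a self-contained sketch of that lifecycle using only CUDA runtime calls; it substitutes the visible device count for the device IDs discovered from the model, and the actual work on the streams is elided.

// stream_lifecycle.cc -- illustrative sketch of the per-device stream pattern;
// compile with nvcc and link against the CUDA runtime.
#include <cuda_runtime.h>

#include <cstdio>
#include <unordered_map>

int main()
{
  int device_count = 0;
  cudaGetDeviceCount(&device_count);

  // One stream per visible device, keyed by device ID (mirrors stream_map_).
  std::unordered_map<int, cudaStream_t> stream_map;
  for (int id = 0; id < device_count; ++id) {
    cudaSetDevice(id);
    cudaStream_t stream;
    cudaStreamCreateWithPriority(&stream, cudaStreamDefault, 0 /* priority */);
    stream_map.insert({id, stream});
  }

  // ... enqueue work on each stream here ...

  // Synchronize every stream before reading results back.
  for (auto& m : stream_map) {
    cudaStreamSynchronize(m.second);
  }

  // Destroy each stream on the device it was created on, as the destructor does.
  for (auto& m : stream_map) {
    cudaSetDevice(m.first);
    cudaError_t err = cudaStreamDestroy(m.second);
    if (err != cudaSuccess) {
      std::fprintf(
          stderr, "failed to destroy stream on device %d: %s\n", m.first,
          cudaGetErrorString(err));
    }
  }
  return 0;
}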
@@ -1006,13 +1074,16 @@ ModelInstanceState::ProcessRequests(
           std::to_string(request_count) + " requests")
           .c_str());

-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
-    at::cuda::CUDAStream torch_stream =
-        at::cuda::getStreamFromExternal(stream_, DeviceId());
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && device_.is_cuda())) {
+    at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromExternal(
+        stream_, (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU)
+                     ? DeviceId()
+                     : device_.index());
     at::cuda::setCurrentCUDAStream(torch_stream);
-#endif
   }
+#endif

   NVTX_RANGE(nvtx_, "ProcessRequests " + Name());
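Because the backend owns raw cudaStream_t handles rather than streams allocated by libtorch, it has to hand its stream to libtorch before running the model: at::cuda::getStreamFromExternal wraps the raw handle and at::cuda::setCurrentCUDAStream makes it current, as the hunk above does. The following is a minimal sketch of that pattern, assuming a CUDA-enabled libtorch build; the header choices and the toy tensor work are assumptions, not taken from the backend source.

// Sketch: make an externally created cudaStream_t the current stream for
// libtorch work on device 0.
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <torch/torch.h>

int main()
{
  cudaSetDevice(0);
  cudaStream_t raw_stream;
  cudaStreamCreate(&raw_stream);

  // Wrap the externally created stream and make it current, so torch ops
  // issued on device 0 are enqueued on raw_stream.
  at::cuda::CUDAStream torch_stream =
      at::cuda::getStreamFromExternal(raw_stream, 0 /* device index */);
  at::cuda::setCurrentCUDAStream(torch_stream);

  // Any CUDA tensor work issued now runs on raw_stream.
  torch::Tensor t = torch::ones({4, 4}, torch::Device(torch::kCUDA, 0));
  t = t + t;

  cudaStreamSynchronize(raw_stream);
  cudaStreamDestroy(raw_stream);
  return 0;
}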

@@ -1118,7 +1189,8 @@ ModelInstanceState::ProcessRequests(
   std::vector<torch::jit::IValue> input_tensors;
   bool cuda_copy = false;
   std::unique_ptr<BackendInputCollector> collector;
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
         responses, request_count, all_response_failed,
@@ -1143,6 +1215,11 @@
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream_);
+    if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map_) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif
@@ -1220,7 +1297,8 @@

   // We don't need an explicit CUDA syncrhonization here since we have already
   // synchronized the stream in the ReadOutputTensors function.
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
   // [FIXME] in the case of cudaEventElapsedTime failure, should handle
   // stats reporting more gracefully as the durations are inaccurate
@@ -1574,7 +1652,9 @@ SetStringInputTensor(
     torch::List<std::string>* input_list, TRITONBACKEND_Input* input,
     const char* name, const uint32_t buffer_count,
     const size_t request_element_cnt, TRITONBACKEND_Response** response,
-    cudaStream_t stream, const char* host_policy_name)
+    cudaStream_t stream,
+    const std::unordered_map<int, cudaStream_t>& stream_map,
+    const char* host_policy_name, const TRITONSERVER_InstanceGroupKind& kind)
 {
   bool cuda_copy = false;
   size_t element_idx = 0;
@@ -1599,6 +1679,11 @@ SetStringInputTensor(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream);
+    if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif  // TRITON_ENABLE_GPU
@@ -1777,7 +1862,8 @@ ModelInstanceState::SetInputTensors(

     // The input must be in contiguous CPU/GPU memory.
     std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
-    if (device_.is_cpu()) {
+    if ((device_.is_cpu()) ||
+        (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL)) {
       alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
                           {TRITONSERVER_MEMORY_CPU, 0}};
     } else {
@@ -1822,7 +1908,8 @@ ModelInstanceState::SetInputTensors(

       *cuda_copy |= SetStringInputTensor(
           &input_list, input, input_name, buffer_count, batch_element_cnt,
-          &((*responses)[idx]), CudaStream(), HostPolicyName().c_str());
+          &((*responses)[idx]), CudaStream(), stream_map_,
+          HostPolicyName().c_str(), Kind());
     }

     (*input_tensors)[input_index_map_[input_name]] = input_list;
@@ -1980,6 +2067,11 @@ ModelInstanceState::ReadOutputTensors(
   // are only guaranteed to be synchronized if the model provides the output
   // on GPU.
   cudaStreamSynchronize(stream_);
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaStreamSynchronize(m.second);
+    }
+  }
 #endif

   return nullptr;
@@ -1989,7 +2081,8 @@ TRITONSERVER_Error*
 ModelInstanceState::RecordBackendTimestamp(
     uint64_t* timestamp, void* cuda_event)
 {
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
     RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(
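The commit title is about timestamps: RecordBackendTimestamp now also takes the CUDA path for 'MODEL' instance groups that ended up with a CUDA stream, and on that path the timestamps come from CUDA events recorded on the stream. The following is a minimal sketch of that event-based timing mechanism, not the backend's exact code; the kernel work is elided and the names are illustrative.

// event_timing.cc -- sketch of timestamping GPU work with CUDA events.
#include <cuda_runtime.h>

#include <cstdio>

int main()
{
  cudaSetDevice(0);
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // Record an event before and after the work enqueued on the stream.
  cudaEventRecord(start, stream);
  // ... enqueue kernels / copies on `stream` here ...
  cudaEventRecord(stop, stream);

  // Wait for the second event, then read the elapsed time between the pair.
  cudaEventSynchronize(stop);
  float elapsed_ms = 0.0f;
  cudaEventElapsedTime(&elapsed_ms, start, stop);
  std::printf("GPU section took %.3f ms\n", elapsed_ms);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaStreamDestroy(stream);
  return 0;
}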
