
Commit a9af556

Handle multi GPU cases when recording timestamps
1 parent 05ae43a commit a9af556

File tree

1 file changed (+117, -15 lines)


src/libtorch.cc

Lines changed: 117 additions & 15 deletions
@@ -81,6 +81,7 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
       std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+      std::unordered_set<int>& device_id_set,
       std::shared_ptr<torch::jit::script::Module>* torch_model);
 
   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -206,6 +207,7 @@ TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
     std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+    std::unordered_set<int>& device_id_set,
     std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
@@ -256,9 +258,23 @@ ModelState::LoadModel(
   try {
     std::istringstream model_stream(model_data_str);
     if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
-      // Don't select the device when loading the model.
       torch_model->reset(
          new torch::jit::Module(torch::jit::load(model_stream)));
+
+      // Get the device used in the model
+      auto parameters = (*torch_model)->parameters();
+      auto buffers = (*torch_model)->buffers();
+
+      for (const auto& parameter : parameters) {
+        if (parameter.device().type() != torch::kCPU) {
+          device_id_set.insert(parameter.device().index());
+        }
+      }
+      for (const auto& buffer : buffers) {
+        if (buffer.device().type() != torch::kCPU) {
+          device_id_set.insert(buffer.device().index());
+        }
+      }
     } else {
       torch_model->reset(
          new torch::jit::Module(torch::jit::load(model_stream, device)));
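
Note on the device-discovery loop added above: when torch::jit::load() is called without an explicit device (the KIND_MODEL path), every parameter and buffer is restored on the device it was saved with, so iterating them reveals each GPU the model actually uses. A minimal standalone sketch of the same idea, assuming a hypothetical "model.pt" whose tensors were saved on specific GPUs:

#include <torch/script.h>

#include <iostream>
#include <unordered_set>

int main() {
  // Load without a target device so the saved device placements are kept.
  torch::jit::Module module = torch::jit::load("model.pt");

  // Collect the index of every CUDA device that holds a parameter or buffer.
  std::unordered_set<int> device_ids;
  for (const auto& param : module.parameters()) {
    if (param.device().is_cuda()) {
      device_ids.insert(param.device().index());
    }
  }
  for (const auto& buffer : module.buffers()) {
    if (buffer.device().is_cuda()) {
      device_ids.insert(buffer.device().index());
    }
  }

  for (int id : device_ids) {
    std::cout << "model uses GPU " << id << std::endl;
  }
  return 0;
}

The diff checks device().type() != torch::kCPU rather than is_cuda(); for the purpose of collecting GPU indices the two checks are effectively the same.
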
@@ -566,6 +582,13 @@ class ModelInstanceState : public BackendModelInstance {
   cudaEvent_t compute_input_start_event_;
   cudaEvent_t compute_infer_start_event_;
   cudaEvent_t compute_output_start_event_;
+
+  // Store the GPU device ID used in a model for the instance group of type
+  // 'MODEL'.
+  std::unordered_set<int> device_id_set_;
+  // Store the extra cuda stream created for the instance group of type
+  // 'MODEL' and use device ID as the key.
+  std::unordered_map<int, cudaStream_t> stream_map_;
 };
 
 TRITONSERVER_Error*
@@ -594,10 +617,43 @@ ModelInstanceState::ModelInstanceState(
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
     device_ = torch::Device(torch::kCUDA, DeviceId());
+#endif
+  }
+
+  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
+      ArtifactFilename(), device_, &model_path_, Kind(), device_id_set_,
+      &torch_model_));
+
+#ifdef TRITON_ENABLE_GPU
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    // Only set the torch device and create a CUDA stream if the model uses GPU.
+    if (!device_id_set_.empty()) {
+      auto it = device_id_set_.begin();
+      // Use the first device to create the default stream.
+      THROW_IF_BACKEND_INSTANCE_ERROR(
+          CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream_));
+      device_ = torch::Device(torch::kCUDA, *it);
+
+      // Create a CUDA stream for other devices so that they can be synchronized
+      // later. Skip the first device since it is used to create the default
+      // stream.
+      if (it != device_id_set_.end()) {
+        ++it;
+      }
+      for (; it != device_id_set_.end(); ++it) {
+        cudaStream_t stream;
+        THROW_IF_BACKEND_INSTANCE_ERROR(
+            CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
+        stream_map_.insert({*it, stream});
+      }
+    }
+  }
+
+  if (device_.is_cuda()) {
     // Need to set the CUDA context so that the context that events are
     // created on match with contexts that events are recorded with.
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-        cudaSetDevice(DeviceId()), TRITONSERVER_ERROR_INTERNAL,
+        cudaSetDevice(device_.index()), TRITONSERVER_ERROR_INTERNAL,
         "Failed to set the device"));
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_input_start_event_),
@@ -608,11 +664,8 @@ ModelInstanceState::ModelInstanceState(
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_output_start_event_),
         TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
-#endif
   }
-
-  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
+#endif
 
   size_t expected_input_cnt = 0;
   {
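
The constructor change above calls the backend's CreateCudaStream helper once per discovered device, keeping the first device's stream as stream_ and storing the remaining streams in stream_map_. Stripped of Triton's error plumbing, the per-device creation reduces to roughly the following CUDA runtime pattern (a sketch under that assumption; CreateStreamsPerDevice is an illustrative name, not a backend function):

#include <cuda_runtime_api.h>

#include <unordered_map>
#include <unordered_set>

// Create one non-default stream per device ID; the caller owns the streams
// and must destroy them with cudaStreamDestroy on the matching device.
cudaError_t
CreateStreamsPerDevice(
    const std::unordered_set<int>& device_ids,
    std::unordered_map<int, cudaStream_t>* stream_map)
{
  for (int device_id : device_ids) {
    cudaError_t err = cudaSetDevice(device_id);
    if (err != cudaSuccess) {
      return err;
    }
    cudaStream_t stream;
    // Priority 0 mirrors the 'cuda_stream_priority' argument in the diff.
    err = cudaStreamCreateWithPriority(&stream, cudaStreamDefault, 0);
    if (err != cudaSuccess) {
      return err;
    }
    stream_map->insert({device_id, stream});
  }
  return cudaSuccess;
}
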
@@ -680,6 +733,21 @@ ModelInstanceState::~ModelInstanceState()
 {
   torch_model_.reset();
   ClearCache();
+
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaSetDevice(m.first);
+      cudaError_t err = cudaStreamDestroy(m.second);
+      if (err != cudaSuccess) {
+        TRITONSERVER_LogMessage(
+            TRITONSERVER_LOG_ERROR, __FILE__, __LINE__,
+            (std::string("~ModelInstanceState: ") + name_ +
+             " failed to destroy cuda stream: " + cudaGetErrorString(err))
+                .c_str());
+      }
+      m.second = nullptr;
+    }
+  }
 }
 
 TRITONSERVER_Error*
@@ -1039,13 +1107,16 @@ ModelInstanceState::ProcessRequests(
        std::to_string(request_count) + " requests")
           .c_str());
 
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
-    at::cuda::CUDAStream torch_stream =
-        at::cuda::getStreamFromExternal(stream_, DeviceId());
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && device_.is_cuda())) {
+    at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromExternal(
+        stream_, (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU)
+                     ? DeviceId()
+                     : device_.index());
     at::cuda::setCurrentCUDAStream(torch_stream);
-#endif
   }
+#endif
 
   NVTX_RANGE(nvtx_, "ProcessRequests " + Name());
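
The block above hands the backend-owned cudaStream_t to libtorch so that kernel launches, input copies, and the timing events all run on the same stream. A minimal sketch of that wrapping, written with the c10::cuda equivalents of the at::cuda calls in the diff (UseExternalStream is an illustrative name):

#include <c10/cuda/CUDAStream.h>

#include <cuda_runtime_api.h>

// Route subsequent libtorch CUDA work onto an externally created stream.
void
UseExternalStream(cudaStream_t raw_stream, c10::DeviceIndex device_index)
{
  // Wrap the raw stream; libtorch does not take ownership of it.
  c10::cuda::CUDAStream torch_stream =
      c10::cuda::getStreamFromExternal(raw_stream, device_index);
  c10::cuda::setCurrentCUDAStream(torch_stream);
}
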

@@ -1151,7 +1222,8 @@ ModelInstanceState::ProcessRequests(
   std::vector<torch::jit::IValue> input_tensors;
   bool cuda_copy = false;
   std::unique_ptr<BackendInputCollector> collector;
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
         responses, request_count, all_response_failed,
@@ -1176,6 +1248,11 @@ ModelInstanceState::ProcessRequests(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream_);
+    if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map_) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif
@@ -1253,7 +1330,8 @@ ModelInstanceState::ProcessRequests(
 
   // We don't need an explicit CUDA syncrhonization here since we have already
   // synchronized the stream in the ReadOutputTensors function.
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     // [FIXME] in the case of cudaEventElapsedTime failure, should handle
     // stats reporting more gracefully as the durations are inaccurate
@@ -1607,7 +1685,9 @@ SetStringInputTensor(
     torch::List<std::string>* input_list, TRITONBACKEND_Input* input,
     const char* name, const uint32_t buffer_count,
     const size_t request_element_cnt, TRITONBACKEND_Response** response,
-    cudaStream_t stream, const char* host_policy_name)
+    cudaStream_t stream,
+    const std::unordered_map<int, cudaStream_t>& stream_map,
+    const char* host_policy_name, const TRITONSERVER_InstanceGroupKind& kind)
 {
   bool cuda_copy = false;
   size_t element_idx = 0;
@@ -1632,6 +1712,11 @@ SetStringInputTensor(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream);
+    if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif  // TRITON_ENABLE_GPU
@@ -1819,6 +1904,16 @@ ModelInstanceState::SetInputTensors(
     }
   }
 
+  // The input must be in contiguous CPU/GPU memory.
+  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
+  if ((device_.is_cpu()) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL)) {
+    alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
+                        {TRITONSERVER_MEMORY_CPU, 0}};
+  } else {
+    alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
+  }
+
   const char* input_buffer;
   size_t batchn_byte_size;
   TRITONSERVER_MemoryType memory_type;
@@ -1857,7 +1952,8 @@ ModelInstanceState::SetInputTensors(
 
       *cuda_copy |= SetStringInputTensor(
           &input_list, input, input_name, buffer_count, batch_element_cnt,
-          &((*responses)[idx]), CudaStream(), HostPolicyName().c_str());
+          &((*responses)[idx]), CudaStream(), stream_map_,
+          HostPolicyName().c_str(), Kind());
     }
 
     (*input_tensors)[input_index_map_[input_name]] = input_list;
@@ -2046,6 +2142,11 @@ ModelInstanceState::ReadOutputTensors(
   // are only guaranteed to be synchronized if the model provides the output
   // on GPU.
   cudaStreamSynchronize(stream_);
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaStreamSynchronize(m.second);
+    }
+  }
 #endif
 
   return nullptr;
@@ -2055,7 +2156,8 @@ TRITONSERVER_Error*
 ModelInstanceState::RecordBackendTimestamp(
     uint64_t* timestamp, void* cuda_event)
 {
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
     RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(
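
With the widened condition above, RecordBackendTimestamp records CUDA events for KIND_MODEL instances as well, as long as a stream was created. The event-based timing it relies on follows the usual CUDA pattern sketched below (simplified; the backend wraps each call in ConvertCUDAStatusToTritonError and records events around compute phases rather than in one function). Note that the events must be created in the same device context as the stream, which is why the constructor now calls cudaSetDevice(device_.index()) before cudaEventCreate.

#include <cuda_runtime_api.h>

// Measure elapsed GPU time between two points on a stream. Assumes the
// caller has already selected the device that owns 'stream'.
float
ElapsedMsOnStream(cudaStream_t stream)
{
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start, stream);
  // ... work enqueued on 'stream' would go here ...
  cudaEventRecord(stop, stream);

  // The elapsed time is only valid after the stop event has completed.
  cudaEventSynchronize(stop);
  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}
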
