
Commit 5a79456

Handle multi GPU cases when recording timestamps
1 parent 3efa323 commit 5a79456

File changed: src/libtorch.cc (109 additions, 16 deletions)
@@ -82,6 +82,7 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
       std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+      std::unordered_set<int>& device_id_set,
       std::shared_ptr<torch::jit::script::Module>* torch_model);

   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -207,6 +208,7 @@ TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
     std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+    std::unordered_set<int>& device_id_set,
     std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
@@ -257,9 +259,23 @@ ModelState::LoadModel(
   try {
     std::istringstream model_stream(model_data_str);
     if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
-      // Don't select the device when loading the model.
       torch_model->reset(
           new torch::jit::Module(torch::jit::load(model_stream)));
+
+      // Get the device used in the model
+      auto parameters = (*torch_model)->parameters();
+      auto buffers = (*torch_model)->buffers();
+
+      for (const auto& parameter : parameters) {
+        if (parameter.device().type() != torch::kCPU) {
+          device_id_set.insert(parameter.device().index());
+        }
+      }
+      for (const auto& buffer : buffers) {
+        if (buffer.device().type() != torch::kCPU) {
+          device_id_set.insert(buffer.device().index());
+        }
+      }
     } else {
       torch_model->reset(
           new torch::jit::Module(torch::jit::load(model_stream, device)));
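
For readers who want to exercise the new device-discovery logic in isolation, the standalone sketch below (not part of the commit; the "model.pt" path and the main() wrapper are assumptions for illustration) loads a TorchScript module and collects the CUDA device indices of its parameters and buffers, the same way the updated LoadModel populates device_id_set:

#include <torch/script.h>

#include <iostream>
#include <unordered_set>

int main() {
  // Load a TorchScript model saved with torch.jit.save() (path is assumed).
  torch::jit::Module module = torch::jit::load("model.pt");

  // Collect every CUDA device index that holds a parameter or a buffer.
  std::unordered_set<int> device_id_set;
  for (const auto& parameter : module.parameters()) {
    if (parameter.device().type() != torch::kCPU) {
      device_id_set.insert(parameter.device().index());
    }
  }
  for (const auto& buffer : module.buffers()) {
    if (buffer.device().type() != torch::kCPU) {
      device_id_set.insert(buffer.device().index());
    }
  }

  for (const int id : device_id_set) {
    std::cout << "model uses CUDA device " << id << std::endl;
  }
  return 0;
}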
@@ -567,6 +583,13 @@ class ModelInstanceState : public BackendModelInstance {
   cudaEvent_t compute_input_start_event_;
   cudaEvent_t compute_infer_start_event_;
   cudaEvent_t compute_output_start_event_;
+
+  // Store the GPU device IDs used in a model for the instance group of type
+  // 'MODEL'.
+  std::unordered_set<int> device_id_set_;
+  // Store the extra CUDA streams created for the instance group of type
+  // 'MODEL', using the device ID as the key.
+  std::unordered_map<int, cudaStream_t> stream_map_;
 };

 TRITONSERVER_Error*
@@ -595,10 +618,43 @@ ModelInstanceState::ModelInstanceState(
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
     device_ = torch::Device(torch::kCUDA, DeviceId());
+#endif
+  }
+
+  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
+      ArtifactFilename(), device_, &model_path_, Kind(), device_id_set_,
+      &torch_model_));
+
+#ifdef TRITON_ENABLE_GPU
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    // Only set the torch device and create a CUDA stream if the model uses GPU.
+    if (!device_id_set_.empty()) {
+      auto it = device_id_set_.begin();
+      // Use the first device to create the default stream.
+      THROW_IF_BACKEND_INSTANCE_ERROR(
+          CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream_));
+      device_ = torch::Device(torch::kCUDA, *it);
+
+      // Create a CUDA stream for other devices so that they can be synchronized
+      // later. Skip the first device since it is used to create the default
+      // stream.
+      if (it != device_id_set_.end()) {
+        ++it;
+      }
+      for (; it != device_id_set_.end(); ++it) {
+        cudaStream_t stream;
+        THROW_IF_BACKEND_INSTANCE_ERROR(
+            CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
+        stream_map_.insert({*it, stream});
+      }
+    }
+  }
+
+  if (device_.is_cuda()) {
     // Need to set the CUDA context so that the context that events are
     // created on match with contexts that events are recorded with.
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-        cudaSetDevice(DeviceId()), TRITONSERVER_ERROR_INTERNAL,
+        cudaSetDevice(device_.index()), TRITONSERVER_ERROR_INTERNAL,
         "Failed to set the device"));
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_input_start_event_),
@@ -609,11 +665,8 @@ ModelInstanceState::ModelInstanceState(
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_output_start_event_),
         TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
-#endif
   }
-
-  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
+#endif

   size_t expected_input_cnt = 0;
   {
@@ -681,6 +734,21 @@ ModelInstanceState::~ModelInstanceState()
 {
   torch_model_.reset();
   ClearCache();
+
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaSetDevice(m.first);
+      cudaError_t err = cudaStreamDestroy(m.second);
+      if (err != cudaSuccess) {
+        TRITONSERVER_LogMessage(
+            TRITONSERVER_LOG_ERROR, __FILE__, __LINE__,
+            (std::string("~ModelInstanceState: ") + name_ +
+             " failed to destroy cuda stream: " + cudaGetErrorString(err))
+                .c_str());
+      }
+      m.second = nullptr;
+    }
+  }
 }

 TRITONSERVER_Error*
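
The constructor and destructor changes above follow a simple pattern: the first discovered device backs the default stream_, every other device gets its own stream in stream_map_, and each stream is later destroyed on the device that owns it. A minimal sketch of that pattern using the plain CUDA runtime is shown below (the hard-coded device set {0, 1} and the main() wrapper are assumptions; the backend itself goes through CreateCudaStream and THROW_IF_BACKEND_INSTANCE_ERROR instead of raw cudaStreamCreate calls):

#include <cuda_runtime_api.h>

#include <unordered_map>
#include <unordered_set>

int main() {
  // Device IDs would normally come from the loaded model; {0, 1} is assumed.
  std::unordered_set<int> device_id_set{0, 1};

  cudaStream_t default_stream = nullptr;
  std::unordered_map<int, cudaStream_t> stream_map;

  auto it = device_id_set.begin();
  if (it != device_id_set.end()) {
    // The first device backs the default stream.
    cudaSetDevice(*it);
    cudaStreamCreate(&default_stream);
    // Every other device gets its own stream, keyed by device ID; streams
    // are created on whichever device is current at creation time.
    for (++it; it != device_id_set.end(); ++it) {
      cudaStream_t stream = nullptr;
      cudaSetDevice(*it);
      cudaStreamCreate(&stream);
      stream_map.insert({*it, stream});
    }
  }

  // ... launch work and cudaStreamSynchronize() each stream as needed ...

  // Destroy each extra stream on its owning device, then the default stream.
  for (auto& m : stream_map) {
    cudaSetDevice(m.first);
    cudaStreamDestroy(m.second);
  }
  if (default_stream != nullptr) {
    cudaStreamDestroy(default_stream);
  }
  return 0;
}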
@@ -1040,13 +1108,16 @@ ModelInstanceState::ProcessRequests(
        std::to_string(request_count) + " requests")
           .c_str());

-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
-    at::cuda::CUDAStream torch_stream =
-        at::cuda::getStreamFromExternal(stream_, DeviceId());
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && device_.is_cuda())) {
+    at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromExternal(
+        stream_, (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU)
+                     ? DeviceId()
+                     : device_.index());
     at::cuda::setCurrentCUDAStream(torch_stream);
-#endif
   }
+#endif

   NVTX_RANGE(nvtx_, "ProcessRequests " + Name());

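The hunk above binds the backend-owned cudaStream_t to PyTorch so that subsequently launched kernels run on it; for the 'MODEL' case the device index now comes from device_ rather than DeviceId(). A minimal sketch of that binding, written with the c10 counterparts of the at::cuda calls used in the diff, might look like the following (the function name, header choice, and c10 spelling are assumptions for this sketch):

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>

// Wrap an externally created CUDA stream and make it PyTorch's current
// stream on the given device, so later torch kernel launches use it.
void BindExternalStream(cudaStream_t raw_stream, int device_id) {
  c10::cuda::CUDAStream torch_stream =
      c10::cuda::getStreamFromExternal(raw_stream, device_id);
  c10::cuda::setCurrentCUDAStream(torch_stream);
}
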
@@ -1152,7 +1223,8 @@ ModelInstanceState::ProcessRequests(
   std::vector<torch::jit::IValue> input_tensors;
   bool cuda_copy = false;
   std::unique_ptr<BackendInputCollector> collector;
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
         responses, request_count, all_response_failed,
@@ -1177,6 +1249,11 @@ ModelInstanceState::ProcessRequests(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream_);
+    if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map_) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif
@@ -1254,7 +1331,8 @@ ModelInstanceState::ProcessRequests(

   // We don't need an explicit CUDA synchronization here since we have already
   // synchronized the stream in the ReadOutputTensors function.
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     // [FIXME] in the case of cudaEventElapsedTime failure, should handle
     // stats reporting more gracefully as the durations are inaccurate
@@ -1608,7 +1686,9 @@ SetStringInputTensor(
     torch::List<std::string>* input_list, TRITONBACKEND_Input* input,
     const char* name, const uint32_t buffer_count,
     const size_t request_element_cnt, TRITONBACKEND_Response** response,
-    cudaStream_t stream, const char* host_policy_name)
+    cudaStream_t stream,
+    const std::unordered_map<int, cudaStream_t>& stream_map,
+    const char* host_policy_name, const TRITONSERVER_InstanceGroupKind& kind)
 {
   bool cuda_copy = false;
   size_t element_idx = 0;
@@ -1633,6 +1713,11 @@ SetStringInputTensor(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream);
+    if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif  // TRITON_ENABLE_GPU
@@ -1812,7 +1897,8 @@ ModelInstanceState::SetInputTensors(

   // The input must be in contiguous CPU/GPU memory.
   std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
-  if (device_.is_cpu()) {
+  if ((device_.is_cpu()) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL)) {
     alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
                         {TRITONSERVER_MEMORY_CPU, 0}};
   } else {
@@ -1857,7 +1943,8 @@ ModelInstanceState::SetInputTensors(

       *cuda_copy |= SetStringInputTensor(
           &input_list, input, input_name, buffer_count, batch_element_cnt,
-          &((*responses)[idx]), CudaStream(), HostPolicyName().c_str());
+          &((*responses)[idx]), CudaStream(), stream_map_,
+          HostPolicyName().c_str(), Kind());
     }

     (*input_tensors)[input_index_map_[input_name]] = input_list;
@@ -2045,6 +2132,11 @@ ModelInstanceState::ReadOutputTensors(
   // are only guaranteed to be synchronized if the model provides the output
   // on GPU.
   cudaStreamSynchronize(stream_);
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaStreamSynchronize(m.second);
+    }
+  }
 #endif

   return nullptr;
@@ -2054,7 +2146,8 @@ TRITONSERVER_Error*
 ModelInstanceState::RecordBackendTimestamp(
     uint64_t* timestamp, void* cuda_event)
 {
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
     RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(
