@@ -81,6 +81,7 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
       std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+      std::unordered_set<int>& device_id_set,
       std::shared_ptr<torch::jit::script::Module>* torch_model);

   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -206,6 +207,7 @@ TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
     std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+    std::unordered_set<int>& device_id_set,
     std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
@@ -256,9 +258,23 @@ ModelState::LoadModel(
   try {
     std::istringstream model_stream(model_data_str);
     if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
-      // Don't select the device when loading the model.
       torch_model->reset(
           new torch::jit::Module(torch::jit::load(model_stream)));
+
+      // Get the device used in the model
+      auto parameters = (*torch_model)->parameters();
+      auto buffers = (*torch_model)->buffers();
+
+      for (const auto& parameter : parameters) {
+        if (parameter.device().type() != torch::kCPU) {
+          device_id_set.insert(parameter.device().index());
+        }
+      }
+      for (const auto& buffer : buffers) {
+        if (buffer.device().type() != torch::kCPU) {
+          device_id_set.insert(buffer.device().index());
+        }
+      }
     } else {
       torch_model->reset(
           new torch::jit::Module(torch::jit::load(model_stream, device)));
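For reference, a standalone sketch of the same device-discovery step: iterate a loaded TorchScript module's parameters and buffers and collect the CUDA device indices they live on. The `model.pt` path and the `main` wrapper are placeholders, not part of the patch.

```cpp
// Minimal illustration of the device-discovery loop added above; "model.pt"
// stands in for any TorchScript model saved with weights on CUDA devices.
#include <torch/script.h>

#include <iostream>
#include <unordered_set>

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");

  std::unordered_set<int> device_ids;
  for (const auto& parameter : module.parameters()) {
    if (parameter.device().is_cuda()) {
      device_ids.insert(parameter.device().index());
    }
  }
  for (const auto& buffer : module.buffers()) {
    if (buffer.device().is_cuda()) {
      device_ids.insert(buffer.device().index());
    }
  }

  for (int id : device_ids) {
    std::cout << "model uses cuda:" << id << std::endl;
  }
  return 0;
}
```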
@@ -566,6 +582,13 @@ class ModelInstanceState : public BackendModelInstance {
   cudaEvent_t compute_input_start_event_;
   cudaEvent_t compute_infer_start_event_;
   cudaEvent_t compute_output_start_event_;
+
+  // Store the GPU device ID used in a model for the instance group of type
+  // 'MODEL'.
+  std::unordered_set<int> device_id_set_;
+  // Store the extra cuda stream created for the instance group of type 'MODEL'
+  // and use device ID as the key.
+  std::unordered_map<int, cudaStream_t> stream_map_;
 };

 TRITONSERVER_Error*
@@ -594,10 +617,43 @@ ModelInstanceState::ModelInstanceState(
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
     device_ = torch::Device(torch::kCUDA, DeviceId());
+#endif
+  }
+
+  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
+      ArtifactFilename(), device_, &model_path_, Kind(), device_id_set_,
+      &torch_model_));
+
+#ifdef TRITON_ENABLE_GPU
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    // Only set the torch device and create a CUDA stream if the model uses GPU.
+    if (!device_id_set_.empty()) {
+      auto it = device_id_set_.begin();
+      // Use the first device to create the default stream.
+      THROW_IF_BACKEND_INSTANCE_ERROR(
+          CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream_));
+      device_ = torch::Device(torch::kCUDA, *it);
+
+      // Create a CUDA stream for other devices so that they can be synchronized
+      // later. Skip the first device since it is used to create the default
+      // stream.
+      if (it != device_id_set_.end()) {
+        ++it;
+      }
+      for (; it != device_id_set_.end(); ++it) {
+        cudaStream_t stream;
+        THROW_IF_BACKEND_INSTANCE_ERROR(
+            CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
+        stream_map_.insert({*it, stream});
+      }
+    }
+  }
+
+  if (device_.is_cuda()) {
     // Need to set the CUDA context so that the context that events are
     // created on match with contexts that events are recorded with.
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-        cudaSetDevice(DeviceId()), TRITONSERVER_ERROR_INTERNAL,
+        cudaSetDevice(device_.index()), TRITONSERVER_ERROR_INTERNAL,
         "Failed to set the device"));
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_input_start_event_),
@@ -608,11 +664,8 @@ ModelInstanceState::ModelInstanceState(
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_output_start_event_),
         TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
-#endif
   }
-
-  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
+#endif

   size_t expected_input_cnt = 0;
   {
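A rough, self-contained sketch of the per-device stream setup above, using the plain CUDA runtime instead of the backend's CreateCudaStream() utility. The flag and priority values are assumptions, not what the patch necessarily uses.

```cpp
// Create one CUDA stream per device ID discovered at model-load time, keyed
// by device ID. This only approximates CreateCudaStream() for the 'MODEL'
// instance-group kind; production code should also validate priority ranges.
#include <cuda_runtime.h>

#include <cstdio>
#include <unordered_map>
#include <unordered_set>

bool CreateStreamsForDevices(
    const std::unordered_set<int>& device_ids,
    std::unordered_map<int, cudaStream_t>* stream_map)
{
  for (int id : device_ids) {
    cudaError_t err = cudaSetDevice(id);
    if (err != cudaSuccess) {
      std::fprintf(
          stderr, "cudaSetDevice(%d): %s\n", id, cudaGetErrorString(err));
      return false;
    }
    cudaStream_t stream;
    err = cudaStreamCreateWithPriority(&stream, cudaStreamDefault, 0);
    if (err != cudaSuccess) {
      std::fprintf(
          stderr, "stream create on %d: %s\n", id, cudaGetErrorString(err));
      return false;
    }
    stream_map->insert({id, stream});
  }
  return true;
}
```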
@@ -680,6 +733,21 @@ ModelInstanceState::~ModelInstanceState()
 {
   torch_model_.reset();
   ClearCache();
+
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaSetDevice(m.first);
+      cudaError_t err = cudaStreamDestroy(m.second);
+      if (err != cudaSuccess) {
+        TRITONSERVER_LogMessage(
+            TRITONSERVER_LOG_ERROR, __FILE__, __LINE__,
+            (std::string("~ModelInstanceState: ") + name_ +
+             " failed to destroy cuda stream: " + cudaGetErrorString(err))
+                .c_str());
+      }
+      m.second = nullptr;
+    }
+  }
 }

 TRITONSERVER_Error*
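The matching teardown, reduced to a minimal helper with logging elided; the device is made current before its stream is destroyed, mirroring the destructor change above. The helper name is illustrative only.

```cpp
// Destroy each extra stream on the device it was created on, then drop the
// map entries. Error handling is intentionally omitted in this sketch.
#include <cuda_runtime.h>

#include <unordered_map>

void DestroyStreams(std::unordered_map<int, cudaStream_t>* stream_map)
{
  for (auto& m : *stream_map) {
    cudaSetDevice(m.first);
    cudaStreamDestroy(m.second);
    m.second = nullptr;
  }
  stream_map->clear();
}
```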
@@ -1039,13 +1107,16 @@ ModelInstanceState::ProcessRequests(
        std::to_string(request_count) + " requests")
           .c_str());

-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
-    at::cuda::CUDAStream torch_stream =
-        at::cuda::getStreamFromExternal(stream_, DeviceId());
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && device_.is_cuda())) {
+    at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromExternal(
+        stream_, (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU)
+                     ? DeviceId()
+                     : device_.index());
     at::cuda::setCurrentCUDAStream(torch_stream);
-#endif
   }
+#endif

   NVTX_RANGE(nvtx_, "ProcessRequests " + Name());
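For context, binding an externally owned cudaStream_t as libtorch's current stream looks roughly like this. The function name and arguments are placeholders for illustration.

```cpp
// Wrap a raw CUDA stream without taking ownership and make it the current
// stream for ATen/TorchScript kernels launched from this thread. Mirrors the
// getStreamFromExternal/setCurrentCUDAStream calls in the hunk above.
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

void UseExternalStream(cudaStream_t raw_stream, int device_id)
{
  at::cuda::CUDAStream torch_stream =
      at::cuda::getStreamFromExternal(raw_stream, device_id);
  at::cuda::setCurrentCUDAStream(torch_stream);
}
```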
@@ -1151,7 +1222,8 @@ ModelInstanceState::ProcessRequests(
   std::vector<torch::jit::IValue> input_tensors;
   bool cuda_copy = false;
   std::unique_ptr<BackendInputCollector> collector;
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
         responses, request_count, all_response_failed,
@@ -1176,6 +1248,11 @@ ModelInstanceState::ProcessRequests(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream_);
+    if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map_) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif
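The same synchronize-everything pattern recurs below in SetStringInputTensor and ReadOutputTensors. A small helper expressing it, as a sketch rather than part of the patch:

```cpp
// Block until the default stream and every per-device extra stream have
// drained, so host-visible buffers filled by async copies are safe to read.
#include <cuda_runtime.h>

#include <unordered_map>

void SynchronizeAllStreams(
    cudaStream_t default_stream,
    const std::unordered_map<int, cudaStream_t>& stream_map)
{
  cudaStreamSynchronize(default_stream);
  for (const auto& m : stream_map) {
    cudaStreamSynchronize(m.second);
  }
}
```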
@@ -1253,7 +1330,8 @@ ModelInstanceState::ProcessRequests(

   // We don't need an explicit CUDA synchronization here since we have already
   // synchronized the stream in the ReadOutputTensors function.
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     // [FIXME] in the case of cudaEventElapsedTime failure, should handle
     // stats reporting more gracefully as the durations are inaccurate
@@ -1607,7 +1685,9 @@ SetStringInputTensor(
     torch::List<std::string>* input_list, TRITONBACKEND_Input* input,
     const char* name, const uint32_t buffer_count,
     const size_t request_element_cnt, TRITONBACKEND_Response** response,
-    cudaStream_t stream, const char* host_policy_name)
+    cudaStream_t stream,
+    const std::unordered_map<int, cudaStream_t>& stream_map,
+    const char* host_policy_name, const TRITONSERVER_InstanceGroupKind& kind)
 {
   bool cuda_copy = false;
   size_t element_idx = 0;
@@ -1632,6 +1712,11 @@ SetStringInputTensor(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream);
+    if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif  // TRITON_ENABLE_GPU
@@ -1819,6 +1904,16 @@ ModelInstanceState::SetInputTensors(
     }
   }

+  // The input must be in contiguous CPU/GPU memory.
+  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
+  if ((device_.is_cpu()) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL)) {
+    alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
+                        {TRITONSERVER_MEMORY_CPU, 0}};
+  } else {
+    alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
+  }
+
   const char* input_buffer;
   size_t batchn_byte_size;
   TRITONSERVER_MemoryType memory_type;
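Why the CPU preference for the 'MODEL' kind: the instance may span several GPUs whose identities come only from the checkpoint, so inputs are gathered into (pinned) host memory and libtorch moves them to the right device when the tensors are built. A reduced sketch of that decision; the helper name is an illustration, not backend API.

```cpp
// Decide where BackendInputCollector should place gathered input bytes.
// Mirrors the alloc_perference logic added above; the TRITONSERVER_* types
// come from the Triton server API header.
#include <cstdint>
#include <utility>
#include <vector>

#include "triton/core/tritonserver.h"

std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>>
PreferredInputMemory(bool model_kind, bool on_cpu, int64_t device_index)
{
  if (on_cpu || model_kind) {
    // Multi-device or CPU execution: gather to host, pinned memory first.
    return {{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
  }
  // Single-GPU execution: gather directly on the instance's device.
  return {{TRITONSERVER_MEMORY_GPU, device_index}};
}
```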
@@ -1857,7 +1952,8 @@ ModelInstanceState::SetInputTensors(

       *cuda_copy |= SetStringInputTensor(
           &input_list, input, input_name, buffer_count, batch_element_cnt,
-          &((*responses)[idx]), CudaStream(), HostPolicyName().c_str());
+          &((*responses)[idx]), CudaStream(), stream_map_,
+          HostPolicyName().c_str(), Kind());
     }

     (*input_tensors)[input_index_map_[input_name]] = input_list;
@@ -2046,6 +2142,11 @@ ModelInstanceState::ReadOutputTensors(
   // are only guaranteed to be synchronized if the model provides the output
   // on GPU.
   cudaStreamSynchronize(stream_);
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaStreamSynchronize(m.second);
+    }
+  }
 #endif

   return nullptr;
@@ -2055,7 +2156,8 @@ TRITONSERVER_Error*
 ModelInstanceState::RecordBackendTimestamp(
     uint64_t* timestamp, void* cuda_event)
 {
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
     RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(