@@ -82,6 +82,7 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
       std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+      std::unordered_set<int>& device_id_set,
       std::shared_ptr<torch::jit::script::Module>* torch_model);
 
   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -207,6 +208,7 @@ TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
     std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+    std::unordered_set<int>& device_id_set,
     std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
@@ -257,9 +259,23 @@ ModelState::LoadModel(
   try {
     std::istringstream model_stream(model_data_str);
     if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
-      // Don't select the device when loading the model.
       torch_model->reset(
           new torch::jit::Module(torch::jit::load(model_stream)));
+
+      // Get the device used in the model
+      auto parameters = (*torch_model)->parameters();
+      auto buffers = (*torch_model)->buffers();
+
+      for (const auto& parameter : parameters) {
+        if (parameter.device().type() != torch::kCPU) {
+          device_id_set.insert(parameter.device().index());
+        }
+      }
+      for (const auto& buffer : buffers) {
+        if (buffer.device().type() != torch::kCPU) {
+          device_id_set.insert(buffer.device().index());
+        }
+      }
     } else {
       torch_model->reset(
           new torch::jit::Module(torch::jit::load(model_stream, device)));
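The MODEL-kind branch above loads the TorchScript module without forcing a device and then records, through the new `device_id_set` out-parameter, every CUDA device that the module's parameters and buffers already live on. A minimal standalone sketch of the same discovery logic follows; the example program, its `argv[1]` model path, and the printout are illustrative only and are not part of the backend.

// Sketch: load a TorchScript model without overriding its device placement,
// then collect the CUDA device IDs that its parameters and buffers live on.
#include <torch/script.h>

#include <iostream>
#include <unordered_set>

int
main(int argc, char* argv[])
{
  if (argc < 2) {
    std::cerr << "usage: " << argv[0] << " <model.pt>" << std::endl;
    return 1;
  }

  // Load on whatever devices the serialized module already references.
  torch::jit::Module module = torch::jit::load(argv[1]);

  std::unordered_set<int> device_id_set;
  for (const auto& parameter : module.parameters()) {
    if (parameter.device().type() != torch::kCPU) {
      device_id_set.insert(parameter.device().index());
    }
  }
  for (const auto& buffer : module.buffers()) {
    if (buffer.device().type() != torch::kCPU) {
      device_id_set.insert(buffer.device().index());
    }
  }

  for (int device_id : device_id_set) {
    std::cout << "model references CUDA device " << device_id << std::endl;
  }
  return 0;
}

In the backend itself the same two loops fill `device_id_set` by reference, so the calling instance knows how many per-device streams it will need.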
@@ -567,6 +583,13 @@ class ModelInstanceState : public BackendModelInstance {
   cudaEvent_t compute_input_start_event_;
   cudaEvent_t compute_infer_start_event_;
   cudaEvent_t compute_output_start_event_;
+
+  // Store the GPU device ID used in a model for the instance group of type
+  // 'MODEL'.
+  std::unordered_set<int> device_id_set_;
+  // Store the extra cuda stream created for the instance group of type 'MODEL'
+  // and use device ID as the key.
+  std::unordered_map<int, cudaStream_t> stream_map_;
 };
 
 TRITONSERVER_Error*
@@ -595,10 +618,43 @@ ModelInstanceState::ModelInstanceState(
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
     device_ = torch::Device(torch::kCUDA, DeviceId());
+#endif
+  }
+
+  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
+      ArtifactFilename(), device_, &model_path_, Kind(), device_id_set_,
+      &torch_model_));
+
+#ifdef TRITON_ENABLE_GPU
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    // Only set the torch device and create a CUDA stream if the model uses GPU.
+    if (!device_id_set_.empty()) {
+      auto it = device_id_set_.begin();
+      // Use the first device to create the default stream.
+      THROW_IF_BACKEND_INSTANCE_ERROR(
+          CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream_));
+      device_ = torch::Device(torch::kCUDA, *it);
+
+      // Create a CUDA stream for other devices so that they can be synchronized
+      // later. Skip the first device since it is used to create the default
+      // stream.
+      if (it != device_id_set_.end()) {
+        ++it;
+      }
+      for (; it != device_id_set_.end(); ++it) {
+        cudaStream_t stream;
+        THROW_IF_BACKEND_INSTANCE_ERROR(
+            CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
+        stream_map_.insert({*it, stream});
+      }
+    }
+  }
+
+  if (device_.is_cuda()) {
     // Need to set the CUDA context so that the context that events are
     // created on match with contexts that events are recorded with.
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-        cudaSetDevice(DeviceId()), TRITONSERVER_ERROR_INTERNAL,
+        cudaSetDevice(device_.index()), TRITONSERVER_ERROR_INTERNAL,
         "Failed to set the device"));
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_input_start_event_),
@@ -609,11 +665,8 @@ ModelInstanceState::ModelInstanceState(
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_output_start_event_),
         TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
-#endif
   }
-
-  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
+#endif
 
   size_t expected_input_cnt = 0;
   {
@@ -681,6 +734,21 @@ ModelInstanceState::~ModelInstanceState()
 {
   torch_model_.reset();
   ClearCache();
+
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaSetDevice(m.first);
+      cudaError_t err = cudaStreamDestroy(m.second);
+      if (err != cudaSuccess) {
+        TRITONSERVER_LogMessage(
+            TRITONSERVER_LOG_ERROR, __FILE__, __LINE__,
+            (std::string("~ModelInstanceState: ") + name_ +
+             " failed to destroy cuda stream: " + cudaGetErrorString(err))
+                .c_str());
+      }
+      m.second = nullptr;
+    }
+  }
 }
 
 TRITONSERVER_Error*
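Taken together, the constructor hunk earlier creates the instance's default stream on the first discovered device and one extra stream per additional device in `stream_map_`, and the destructor hunk above switches to the owning device before destroying each of them. A rough standalone approximation of that lifecycle with raw CUDA runtime calls is sketched below; the helper names and the `main` driver are invented for illustration, and the backend itself goes through its `CreateCudaStream` utility rather than these direct calls.

// Sketch: one CUDA stream per device ID, keyed by device, plus the matching
// teardown pass. CreateStreamsForDevices/DestroyStreams are hypothetical
// helpers, not part of the backend.
#include <cuda_runtime.h>

#include <unordered_map>
#include <unordered_set>

static bool
CreateStreamsForDevices(
    const std::unordered_set<int>& device_id_set,
    std::unordered_map<int, cudaStream_t>* stream_map)
{
  for (int device_id : device_id_set) {
    cudaStream_t stream;
    // Streams are created on the currently selected device.
    if (cudaSetDevice(device_id) != cudaSuccess ||
        cudaStreamCreateWithPriority(&stream, cudaStreamDefault, 0) !=
            cudaSuccess) {
      return false;
    }
    stream_map->insert({device_id, stream});
  }
  return true;
}

static void
DestroyStreams(std::unordered_map<int, cudaStream_t>* stream_map)
{
  // Mirror the destructor change: switch to the owning device, then destroy.
  for (auto& m : *stream_map) {
    cudaSetDevice(m.first);
    cudaStreamDestroy(m.second);
    m.second = nullptr;
  }
  stream_map->clear();
}

int
main()
{
  // For the demo, pretend the model touched every visible device.
  int device_count = 0;
  cudaGetDeviceCount(&device_count);
  std::unordered_set<int> device_id_set;
  for (int i = 0; i < device_count; ++i) {
    device_id_set.insert(i);
  }

  std::unordered_map<int, cudaStream_t> stream_map;
  if (CreateStreamsForDevices(device_id_set, &stream_map)) {
    DestroyStreams(&stream_map);
  }
  return 0;
}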
@@ -1040,13 +1108,16 @@ ModelInstanceState::ProcessRequests(
        std::to_string(request_count) + " requests")
           .c_str());
 
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
-    at::cuda::CUDAStream torch_stream =
-        at::cuda::getStreamFromExternal(stream_, DeviceId());
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && device_.is_cuda())) {
+    at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromExternal(
+        stream_, (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU)
+                     ? DeviceId()
+                     : device_.index());
     at::cuda::setCurrentCUDAStream(torch_stream);
-#endif
   }
+#endif
 
   NVTX_RANGE(nvtx_, "ProcessRequests " + Name());
 
@@ -1152,7 +1223,8 @@ ModelInstanceState::ProcessRequests(
   std::vector<torch::jit::IValue> input_tensors;
   bool cuda_copy = false;
   std::unique_ptr<BackendInputCollector> collector;
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
         responses, request_count, all_response_failed,
@@ -1177,6 +1249,11 @@ ModelInstanceState::ProcessRequests(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream_);
+    if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map_) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif
@@ -1254,7 +1331,8 @@ ModelInstanceState::ProcessRequests(
 
   // We don't need an explicit CUDA synchronization here since we have already
   // synchronized the stream in the ReadOutputTensors function.
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
   // [FIXME] in the case of cudaEventElapsedTime failure, should handle
   // stats reporting more gracefully as the durations are inaccurate
@@ -1608,7 +1686,9 @@ SetStringInputTensor(
     torch::List<std::string>* input_list, TRITONBACKEND_Input* input,
     const char* name, const uint32_t buffer_count,
     const size_t request_element_cnt, TRITONBACKEND_Response** response,
-    cudaStream_t stream, const char* host_policy_name)
+    cudaStream_t stream,
+    const std::unordered_map<int, cudaStream_t>& stream_map,
+    const char* host_policy_name, const TRITONSERVER_InstanceGroupKind& kind)
 {
   bool cuda_copy = false;
   size_t element_idx = 0;
@@ -1633,6 +1713,11 @@ SetStringInputTensor(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream);
+    if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif  // TRITON_ENABLE_GPU
@@ -1812,7 +1897,8 @@ ModelInstanceState::SetInputTensors(
 
   // The input must be in contiguous CPU/GPU memory.
   std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
-  if (device_.is_cpu()) {
+  if ((device_.is_cpu()) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL)) {
     alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
                         {TRITONSERVER_MEMORY_CPU, 0}};
   } else {
@@ -1857,7 +1943,8 @@ ModelInstanceState::SetInputTensors(
 
       *cuda_copy |= SetStringInputTensor(
           &input_list, input, input_name, buffer_count, batch_element_cnt,
-          &((*responses)[idx]), CudaStream(), HostPolicyName().c_str());
+          &((*responses)[idx]), CudaStream(), stream_map_,
+          HostPolicyName().c_str(), Kind());
     }
 
     (*input_tensors)[input_index_map_[input_name]] = input_list;
@@ -2045,6 +2132,11 @@ ModelInstanceState::ReadOutputTensors(
   // are only guaranteed to be synchronized if the model provides the output
   // on GPU.
   cudaStreamSynchronize(stream_);
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaStreamSynchronize(m.second);
+    }
+  }
 #endif
 
   return nullptr;
@@ -2054,7 +2146,8 @@ TRITONSERVER_Error*
 ModelInstanceState::RecordBackendTimestamp(
     uint64_t* timestamp, void* cuda_event)
 {
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
     RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(