@@ -79,6 +79,7 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
       std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+      std::unordered_set<int>& device_id_set,
       std::shared_ptr<torch::jit::script::Module>* torch_model);

   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -204,6 +205,7 @@ TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
     std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
+    std::unordered_set<int>& device_id_set,
     std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
@@ -254,9 +256,23 @@ ModelState::LoadModel(
   try {
     std::istringstream model_stream(model_data_str);
     if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
-      // Don't select the device when loading the model.
       torch_model->reset(
           new torch::jit::Module(torch::jit::load(model_stream)));
+
+      // Get the device used in the model
+      auto parameters = (*torch_model)->parameters();
+      auto buffers = (*torch_model)->buffers();
+
+      for (const auto& parameter : parameters) {
+        if (parameter.device().type() != torch::kCPU) {
+          device_id_set.insert(parameter.device().index());
+        }
+      }
+      for (const auto& buffer : buffers) {
+        if (buffer.device().type() != torch::kCPU) {
+          device_id_set.insert(buffer.device().index());
+        }
+      }
     } else {
       torch_model->reset(
           new torch::jit::Module(torch::jit::load(model_stream, device)));
@@ -559,6 +575,13 @@ class ModelInstanceState : public BackendModelInstance {
   cudaEvent_t compute_input_start_event_;
   cudaEvent_t compute_infer_start_event_;
   cudaEvent_t compute_output_start_event_;
+
+  // Store the GPU device ID used in a model for the instance group of type
+  // 'MODEL'.
+  std::unordered_set<int> device_id_set_;
+  // Store the extra cuda stream created for the instance group of type 'MODEL'
+  // and use device ID as the key.
+  std::unordered_map<int, cudaStream_t> stream_map_;
 };

 TRITONSERVER_Error*
@@ -587,10 +610,43 @@ ModelInstanceState::ModelInstanceState(
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
     device_ = torch::Device(torch::kCUDA, DeviceId());
+#endif
+  }
+
+  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
+      ArtifactFilename(), device_, &model_path_, Kind(), device_id_set_,
+      &torch_model_));
+
+#ifdef TRITON_ENABLE_GPU
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    // Only set the torch device and create a CUDA stream if the model uses GPU.
+    if (!device_id_set_.empty()) {
+      auto it = device_id_set_.begin();
+      // Use the first device to create the default stream.
+      THROW_IF_BACKEND_INSTANCE_ERROR(
+          CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream_));
+      device_ = torch::Device(torch::kCUDA, *it);
+
+      // Create a CUDA stream for other devices so that they can be synchronized
+      // later. Skip the first device since it is used to create the default
+      // stream.
+      if (it != device_id_set_.end()) {
+        ++it;
+      }
+      for (; it != device_id_set_.end(); ++it) {
+        cudaStream_t stream;
+        THROW_IF_BACKEND_INSTANCE_ERROR(
+            CreateCudaStream(*it, 0 /* cuda_stream_priority */, &stream));
+        stream_map_.insert({*it, stream});
+      }
+    }
+  }
+
+  if (device_.is_cuda()) {
     // Need to set the CUDA context so that the context that events are
     // created on match with contexts that events are recorded with.
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-        cudaSetDevice(DeviceId()), TRITONSERVER_ERROR_INTERNAL,
+        cudaSetDevice(device_.index()), TRITONSERVER_ERROR_INTERNAL,
         "Failed to set the device"));
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_input_start_event_),
@@ -601,11 +657,8 @@ ModelInstanceState::ModelInstanceState(
     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
         cudaEventCreate(&compute_output_start_event_),
         TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
-#endif
   }
-
-  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
+#endif

   size_t expected_input_cnt = 0;
   {
@@ -667,6 +720,21 @@ ModelInstanceState::~ModelInstanceState()
 {
   torch_model_.reset();
   ClearCache();
+
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaSetDevice(m.first);
+      cudaError_t err = cudaStreamDestroy(m.second);
+      if (err != cudaSuccess) {
+        TRITONSERVER_LogMessage(
+            TRITONSERVER_LOG_ERROR, __FILE__, __LINE__,
+            (std::string("~ModelInstanceState: ") + name_ +
+             " failed to destroy cuda stream: " + cudaGetErrorString(err))
+                .c_str());
+      }
+      m.second = nullptr;
+    }
+  }
 }

 TRITONSERVER_Error*
@@ -1006,13 +1074,16 @@ ModelInstanceState::ProcessRequests(
        std::to_string(request_count) + " requests")
           .c_str());

-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
 #ifdef TRITON_ENABLE_GPU
-    at::cuda::CUDAStream torch_stream =
-        at::cuda::getStreamFromExternal(stream_, DeviceId());
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && device_.is_cuda())) {
+    at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromExternal(
+        stream_, (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU)
+                     ? DeviceId()
+                     : device_.index());
     at::cuda::setCurrentCUDAStream(torch_stream);
-#endif
   }
+#endif

   NVTX_RANGE(nvtx_, "ProcessRequests " + Name());

@@ -1118,7 +1189,8 @@ ModelInstanceState::ProcessRequests(
   std::vector<torch::jit::IValue> input_tensors;
   bool cuda_copy = false;
   std::unique_ptr<BackendInputCollector> collector;
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
         responses, request_count, all_response_failed,
@@ -1143,6 +1215,11 @@ ModelInstanceState::ProcessRequests(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream_);
+    if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map_) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif
@@ -1220,7 +1297,8 @@ ModelInstanceState::ProcessRequests(

   // We don't need an explicit CUDA synchronization here since we have already
   // synchronized the stream in the ReadOutputTensors function.
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     // [FIXME] in the case of cudaEventElapsedTime failure, should handle
     // stats reporting more gracefully as the durations are inaccurate
@@ -1574,7 +1652,9 @@ SetStringInputTensor(
     torch::List<std::string>* input_list, TRITONBACKEND_Input* input,
     const char* name, const uint32_t buffer_count,
     const size_t request_element_cnt, TRITONBACKEND_Response** response,
-    cudaStream_t stream, const char* host_policy_name)
+    cudaStream_t stream,
+    const std::unordered_map<int, cudaStream_t>& stream_map,
+    const char* host_policy_name, const TRITONSERVER_InstanceGroupKind& kind)
 {
   bool cuda_copy = false;
   size_t element_idx = 0;
@@ -1599,6 +1679,11 @@ SetStringInputTensor(
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(stream);
+    if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      for (auto& m : stream_map) {
+        cudaStreamSynchronize(m.second);
+      }
+    }
     cuda_copy = false;
   }
 #endif  // TRITON_ENABLE_GPU
@@ -1777,7 +1862,8 @@ ModelInstanceState::SetInputTensors(

   // The input must be in contiguous CPU/GPU memory.
   std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
-  if (device_.is_cpu()) {
+  if ((device_.is_cpu()) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL)) {
     alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
                         {TRITONSERVER_MEMORY_CPU, 0}};
   } else {
@@ -1822,7 +1908,8 @@ ModelInstanceState::SetInputTensors(

       *cuda_copy |= SetStringInputTensor(
           &input_list, input, input_name, buffer_count, batch_element_cnt,
-          &((*responses)[idx]), CudaStream(), HostPolicyName().c_str());
+          &((*responses)[idx]), CudaStream(), stream_map_,
+          HostPolicyName().c_str(), Kind());
     }

     (*input_tensors)[input_index_map_[input_name]] = input_list;
@@ -1980,6 +2067,11 @@ ModelInstanceState::ReadOutputTensors(
   // are only guaranteed to be synchronized if the model provides the output
   // on GPU.
   cudaStreamSynchronize(stream_);
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+    for (auto& m : stream_map_) {
+      cudaStreamSynchronize(m.second);
+    }
+  }
 #endif

   return nullptr;
@@ -1989,7 +2081,8 @@ TRITONSERVER_Error*
 ModelInstanceState::RecordBackendTimestamp(
     uint64_t* timestamp, void* cuda_event)
 {
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL && stream_ != nullptr)) {
 #ifdef TRITON_ENABLE_GPU
     cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
     RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(
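
Note: the device-discovery loop added to ModelState::LoadModel above can be exercised on its own with LibTorch. The following standalone sketch is illustrative only and not part of the patch; "model.pt" is a placeholder path and the module contents are hypothetical.

// Sketch: collect the CUDA device indices that a TorchScript module's
// parameters and buffers live on, mirroring the loop in LoadModel above.
#include <torch/script.h>

#include <iostream>
#include <unordered_set>

int main()
{
  // Placeholder path; any serialized TorchScript module would do.
  torch::jit::Module module = torch::jit::load("model.pt");

  std::unordered_set<int> device_id_set;
  for (const auto& parameter : module.parameters()) {
    if (parameter.device().type() != torch::kCPU) {
      device_id_set.insert(parameter.device().index());
    }
  }
  for (const auto& buffer : module.buffers()) {
    if (buffer.device().type() != torch::kCPU) {
      device_id_set.insert(buffer.device().index());
    }
  }

  for (int id : device_id_set) {
    std::cout << "model uses CUDA device " << id << std::endl;
  }
  return 0;
}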
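Similarly, the per-device stream handling added to ModelInstanceState (one stream per discovered device, synchronized after copies and inference, destroyed in the destructor) can be approximated with the plain CUDA runtime API. CreateCudaStream() is a Triton backend utility, so this sketch substitutes cudaStreamCreateWithPriority(); the device set {0, 1} is hypothetical and error checking is omitted for brevity.

// Sketch: one non-default stream per CUDA device, synchronized and destroyed
// in the same pattern as stream_map_ above.
#include <cuda_runtime.h>

#include <unordered_map>
#include <unordered_set>

int main()
{
  std::unordered_set<int> device_id_set = {0, 1};  // e.g. discovered as above
  std::unordered_map<int, cudaStream_t> stream_map;

  // Create a stream on each device (priority 0, matching the patch's
  // cuda_stream_priority argument).
  for (int device_id : device_id_set) {
    cudaSetDevice(device_id);
    cudaStream_t stream;
    cudaStreamCreateWithPriority(&stream, cudaStreamDefault, 0);
    stream_map.insert({device_id, stream});
  }

  // ... launch work on the per-device streams ...

  // Synchronize every stream, as ReadOutputTensors does after inference.
  for (auto& m : stream_map) {
    cudaStreamSynchronize(m.second);
  }

  // Destroy each stream on its owning device, as in ~ModelInstanceState.
  for (auto& m : stream_map) {
    cudaSetDevice(m.first);
    cudaStreamDestroy(m.second);
    m.second = nullptr;
  }
  return 0;
}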