@@ -64,29 +64,10 @@ namespace triton { namespace backend { namespace pytorch {
namespace {

#ifdef TRITON_ENABLE_GPU
- void CUDART_CB
- CaptureFirstTimestampCallback(void* data)
- {
-   auto* tuple = reinterpret_cast<std::tuple<uint64_t*, std::mutex*>*>(data);
-
-   uint64_t* timestamp = std::get<0>(*tuple);
-   std::mutex* mu = std::get<1>(*tuple);
-
-   std::lock_guard<std::mutex> lock(*mu);
-   if (*timestamp == 0) {
-     SET_TIMESTAMP(*timestamp);
-   }
- }
-
void CUDART_CB
CaptureLastTimestampCallback(void* data)
{
-   auto* tuple = reinterpret_cast<std::tuple<uint64_t*, std::mutex*>*>(data);
-
-   uint64_t* timestamp = std::get<0>(*tuple);
-   std::mutex* mu = std::get<1>(*tuple);
-
-   std::lock_guard<std::mutex> lock(*mu);
+   auto* timestamp = reinterpret_cast<std::atomic<uint64_t>*>(data);
  SET_TIMESTAMP(*timestamp);
}
#endif
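For orientation, here is a minimal standalone sketch (not part of the diff) of the pattern this hunk adopts: when several streams invoke the same host callback, a single std::atomic<uint64_t> is enough because each store is atomic, so the tuple-plus-mutex bookkeeping removed above is unnecessary. The helper names and the std::chrono-based timestamp are illustrative assumptions; the backend itself uses its SET_TIMESTAMP macro.

// Illustrative only: capture a host-side timestamp from several CUDA streams
// into one std::atomic, mirroring CaptureLastTimestampCallback above.
#include <cuda_runtime.h>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <vector>

static void CUDART_CB CaptureTimestamp(void* data)
{
  auto* ts = reinterpret_cast<std::atomic<uint64_t>*>(data);
  // Last writer wins; the atomic store makes a mutex unnecessary.
  ts->store(std::chrono::duration_cast<std::chrono::nanoseconds>(
                std::chrono::steady_clock::now().time_since_epoch())
                .count());
}

static void EnqueueTimestampCapture(
    const std::vector<cudaStream_t>& streams, std::atomic<uint64_t>* ts)
{
  for (cudaStream_t stream : streams) {
    // The callback runs once all work previously enqueued on 'stream' is done.
    cudaLaunchHostFunc(stream, CaptureTimestamp, ts);
  }
}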
@@ -571,6 +552,9 @@ class ModelInstanceState : public BackendModelInstance {
      NamingConvention* naming_convention,
      const std::vector<std::string>& allowed_io);

+   // Create CUDA events for statistics collection.
+   void CreateCudaEvents(const int32_t& device_id);
+
  // Get the appropriate CUDA stream for input and output handling based on the
  // instance group type.
  cudaStream_t GetCudaStreamByInstanceKind();
@@ -580,6 +564,10 @@ class ModelInstanceState : public BackendModelInstance {
  void SetCurrentCudaStream(
      const cudaStream_t& stream, const int32_t& device_id);

+   // Get the elapsed time between two CUDA events.
+   float GetCudaEventElapsedTime(
+       const cudaEvent_t& start_event, const cudaEvent_t& end_event);
+
  ModelState* model_state_;

  // The full path to the TorchScript model file.
@@ -610,6 +598,9 @@ class ModelInstanceState : public BackendModelInstance {

  // Store the cuda streams created for the 'KIND_MODEL' instance group.
  std::vector<cudaStream_t> stream_vec_;
+
+   // The number of available devices.
+   int device_cnt_;
};

TRITONSERVER_Error*
@@ -633,47 +624,37 @@ ModelInstanceState::Create(
ModelInstanceState::ModelInstanceState(
    ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance)
    : BackendModelInstance(model_state, triton_model_instance),
-       model_state_(model_state), device_(torch::kCPU), is_dict_input_(false)
+       model_state_(model_state), device_(torch::kCPU), is_dict_input_(false),
+       device_cnt_(0)
{
  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
#ifdef TRITON_ENABLE_GPU
    device_ = torch::Device(torch::kCUDA, DeviceId());
-     // Need to set the CUDA context so that the context that events are
-     // created on match with contexts that events are recorded with.
-     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-         cudaSetDevice(DeviceId()), TRITONSERVER_ERROR_INTERNAL,
-         "Failed to set the device"));
-     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-         cudaEventCreate(&compute_input_start_event_),
-         TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
-     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-         cudaEventCreate(&compute_infer_start_event_),
-         TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
-     THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
-         cudaEventCreate(&compute_output_start_event_),
-         TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
+     CreateCudaEvents(DeviceId());
#endif
  }

+ #ifdef TRITON_ENABLE_GPU
+   device_cnt_ = torch::cuda::device_count();
+ #endif
+
  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
      ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));

  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
#ifdef TRITON_ENABLE_GPU
-     // Since we cannot determine the exact devices used by the model, we create
-     // a CUDA stream for every available device to ensure proper synchronization
-     // of CUDA streams. This approach may have implications when a timestamp is
-     // captured on a device that is not used by the model. Currently, this issue
-     // is addressed by synchronizing the CUDA streams before recording
-     // timestamps to prevent timestamp skewing. However, in the future, any
-     // modifications to the CUDA stream synchronization logic should be handled
-     // with caution.
-     for (int i = 0; i < torch::cuda::device_count(); i++) {
+     // Create a CUDA stream for every available device.
+     for (int i = 0; i < device_cnt_; i++) {
      cudaStream_t stream;
      THROW_IF_BACKEND_INSTANCE_ERROR(
          CreateCudaStream(i, 0 /* cuda_stream_priority */, &stream));
      stream_vec_.push_back(stream);
    }
+     if (!stream_vec_.empty()) {
+       // Create CUDA events on the first device that will be used for collecting
+       // inputs/outputs.
+       CreateCudaEvents(0);
+     }
#endif
  }

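As a side note, the "one stream per available device" setup above can be sketched standalone with the raw CUDA runtime API instead of the backend's CreateCudaStream() helper (names here are illustrative assumptions, not backend code): a stream is bound to whichever device is current when it is created, which is why the device is selected first.

// Illustrative only: create one stream per visible CUDA device.
#include <cuda_runtime.h>
#include <vector>

static std::vector<cudaStream_t> CreateStreamPerDevice()
{
  int device_cnt = 0;
  if (cudaGetDeviceCount(&device_cnt) != cudaSuccess) {
    return {};
  }
  std::vector<cudaStream_t> streams;
  for (int i = 0; i < device_cnt; ++i) {
    // A stream belongs to the device that is current at creation time.
    cudaSetDevice(i);
    cudaStream_t stream;
    if (cudaStreamCreateWithFlags(&stream, cudaStreamDefault) == cudaSuccess) {
      streams.push_back(stream);
    }
  }
  return streams;
}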
@@ -733,8 +714,8 @@
ModelInstanceState::ClearCache()
{
#ifdef TRITON_ENABLE_GPU
-   if (device_.is_cuda() || ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) &&
-                             (torch::cuda::device_count() > 0))) {
+   if (device_.is_cuda() ||
+       ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) {
    c10::cuda::CUDACachingAllocator::emptyCache();
  }
#endif  // TRITON_ENABLE_GPU
@@ -1237,27 +1218,21 @@ ModelInstanceState::ProcessRequests(
  std::vector<torch::jit::IValue> input_tensors;
  bool cuda_copy = false;
  std::unique_ptr<BackendInputCollector> collector;
-   std::mutex timestamp_mu;

-   uint64_t compute_input_start = 0;
-   std::tuple<uint64_t*, std::mutex*> compute_input_cb_data(
-       &compute_input_start, &timestamp_mu);
-
-   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+   // For 'KIND_MODEL', it's fine to use CUDA events to calculate the compute
+   // input duration since only one stream will be used for input collection.
+   // However, for the compute infer duration, multiple streams will be involved,
+   // so we need to launch a CUDA callback function for timestamp capturing, as
+   // demonstrated in the 'RecordBackendTimestamp' function.
+   if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+       ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) {
#ifdef TRITON_ENABLE_GPU
    RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
        responses, request_count, all_response_failed,
        ConvertCUDAStatusToTritonError(
-             cudaEventRecord(compute_input_start_event_, stream_),
+             cudaEventRecord(
+                 compute_input_start_event_, GetCudaStreamByInstanceKind()),
            TRITONSERVER_ERROR_INTERNAL, "Failed to record the event."));
- #endif
-   } else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
- #ifdef TRITON_ENABLE_GPU
-     for (const auto& stream : stream_vec_) {
-       cudaLaunchHostFunc(
-           stream, CaptureFirstTimestampCallback,
-           reinterpret_cast<void*>(&compute_input_cb_data));
-     }
#endif
  }

@@ -1283,16 +1258,28 @@ ModelInstanceState::ProcessRequests(

  std::vector<torch::jit::IValue> output_tensors;
  uint64_t compute_start_ns = 0;
-   uint64_t compute_infer_start = 0;
+   std::atomic<uint64_t> compute_infer_start = 0;
+
+   // Record 'compute_infer_start_event_' for 'KIND_MODEL' to calculate the
+   // compute input duration. The compute infer start timestamp will be recorded
+   // in the 'RecordBackendTimestamp' function.
+   if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
+ #ifdef TRITON_ENABLE_GPU
+     RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
+         responses, request_count, all_response_failed,
+         ConvertCUDAStatusToTritonError(
+             cudaEventRecord(
+                 compute_infer_start_event_, GetCudaStreamByInstanceKind()),
+             TRITONSERVER_ERROR_INTERNAL, "Failed to record the event."));
+ #endif
+   }

-   std::tuple<uint64_t*, std::mutex*> compute_infer_cb_data(
-       &compute_infer_start, &timestamp_mu);
  RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
      responses, request_count, all_response_failed,
      RecordBackendTimestamp(
          &compute_start_ns,
          reinterpret_cast<void*>(&compute_infer_start_event_),
-           reinterpret_cast<void*>(&compute_infer_cb_data)));
+           reinterpret_cast<void*>(&compute_infer_start)));

  // Run...
  if (!all_response_failed) {
@@ -1324,16 +1311,14 @@ ModelInstanceState::ProcessRequests(
  }

  uint64_t compute_end_ns = 0;
-   uint64_t compute_output_start = 0;
-   std::tuple<uint64_t*, std::mutex*> compute_output_cb_data(
-       &compute_output_start, &timestamp_mu);
+   std::atomic<uint64_t> compute_output_start = 0;

  RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
      responses, request_count, all_response_failed,
      RecordBackendTimestamp(
          &compute_end_ns,
          reinterpret_cast<void*>(&compute_output_start_event_),
-           reinterpret_cast<void*>(&compute_output_cb_data)));
+           reinterpret_cast<void*>(&compute_output_start)));

#ifdef TRITON_ENABLE_GPU
  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
@@ -1373,35 +1358,25 @@ ModelInstanceState::ProcessRequests(
  // synchronized the stream in the ReadOutputTensors function.
  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
#ifdef TRITON_ENABLE_GPU
-     // [FIXME] in the case of cudaEventElapsedTime failure, should handle
-     // stats reporting more gracefully as the durations are inaccurate
-     float compute_input_duration = 0;
-     float compute_infer_duration = 0;
-     LOG_IF_ERROR(
-         ConvertCUDAStatusToTritonError(
-             cudaEventElapsedTime(
-                 &compute_input_duration, compute_input_start_event_,
-                 compute_infer_start_event_),
-             TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
-         "Failed to capture elapsed time");
-
-     LOG_IF_ERROR(
-         ConvertCUDAStatusToTritonError(
-             cudaEventElapsedTime(
-                 &compute_infer_duration, compute_infer_start_event_,
-                 compute_output_start_event_),
-             TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
-         "Failed to capture elapsed time");
+     float compute_input_duration = GetCudaEventElapsedTime(
+         compute_input_start_event_, compute_infer_start_event_);
+     float compute_infer_duration = GetCudaEventElapsedTime(
+         compute_infer_start_event_, compute_output_start_event_);

    compute_start_ns = exec_start_ns + (compute_input_duration * 1e6);
    compute_end_ns = compute_start_ns + (compute_infer_duration * 1e6);
#endif
-   } else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
-     uint64_t compute_input_duration = compute_infer_start - compute_input_start;
+   } else if (
+       (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
+ #ifdef TRITON_ENABLE_GPU
+     float compute_input_duration = GetCudaEventElapsedTime(
+         compute_input_start_event_, compute_infer_start_event_);
    uint64_t compute_infer_duration =
        compute_output_start - compute_infer_start;
-     compute_start_ns = exec_start_ns + compute_input_duration;
+
+     compute_start_ns = exec_start_ns + (compute_input_duration * 1e6);
    compute_end_ns = compute_start_ns + compute_infer_duration;
+ #endif
  }

  // Report statistics for each request.
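The unit handling in the hunk above is worth spelling out: cudaEventElapsedTime() reports a float in milliseconds, while Triton statistics use nanoseconds, hence the multiplication by 1e6; the host-callback timestamps (compute_infer_start, compute_output_start) are already nanoseconds and are subtracted directly. A small self-contained example with made-up values:

// Illustrative arithmetic only; the durations are invented for the example.
#include <cinttypes>
#include <cstdint>
#include <cstdio>

int main()
{
  uint64_t exec_start_ns = 1000000000;    // hypothetical request start (ns)
  float compute_input_duration = 0.25f;   // ms, as cudaEventElapsedTime reports
  float compute_infer_duration = 2.0f;    // ms

  uint64_t compute_start_ns =
      exec_start_ns + static_cast<uint64_t>(compute_input_duration * 1e6);
  uint64_t compute_end_ns =
      compute_start_ns + static_cast<uint64_t>(compute_infer_duration * 1e6);

  // Prints "1000250000 1002250000".
  std::printf("%" PRIu64 " %" PRIu64 "\n", compute_start_ns, compute_end_ns);
  return 0;
}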
@@ -1473,7 +1448,7 @@ ModelInstanceState::Execute(
  bool is_device_gpu =
      (device_.is_cuda() ||
       ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) &&
-         (torch::cuda::device_count() > 0)));
+         (device_cnt_ > 0)));
  if (std::get<1>(model_state_->EnabledNvfuserPair()) && is_device_gpu) {
    torch::jit::overrideCanFuseOnCPU(false);
    torch::jit::overrideCanFuseOnGPU(false);
@@ -2195,17 +2170,17 @@ ModelInstanceState::RecordBackendTimestamp(
    uint64_t* timestamp, void* cuda_event, void* timestamp_cb_data)
{
  // For the 'KIND_GPU' instance group, we use a CUDA event to record the
-   // timestamp. For the 'KIND_MODEL' instance group, it is complicated to
-   // calculate the elapsed time between two cuda events from different devices,
-   // so we launch a CUDA callback function to record the timestamp.
+   // timestamp. For the 'KIND_MODEL' instance group, launch a CUDA callback
+   // function to record the timestamp for multiple streams.
  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
#ifdef TRITON_ENABLE_GPU
    cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
    RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(
        cudaEventRecord(*lcuda_event, stream_), TRITONSERVER_ERROR_INTERNAL,
        "Failed to record the event."));
#endif
-   } else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+   } else if (
+       (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
#ifdef TRITON_ENABLE_GPU
    for (const auto& stream : stream_vec_) {
      cudaLaunchHostFunc(
@@ -2219,17 +2194,23 @@ ModelInstanceState::RecordBackendTimestamp(
}

void
- ModelInstanceState::SetCurrentCudaStream(
-     const cudaStream_t& stream, const int& device_id)
+ ModelInstanceState::CreateCudaEvents(const int32_t& device_id)
{
#ifdef TRITON_ENABLE_GPU
-   at::cuda::CUDAStream torch_stream =
-       at::cuda::getStreamFromExternal(stream, device_id);
-   // This function replaces the default stream with the stream we created. It
-   // is not necessary to change the current device to the desired device when
-   // replacing the default stream for that device. See the documentation here:
-   // https://pytorch.org/cppdocs/api/function_namespacec10_1_1cuda_1a6ed50cc0fc16cc7014d9c2f4c3bd098d.html
-   at::cuda::setCurrentCUDAStream(torch_stream);
+   // Need to set the CUDA context so that the context that events are
+   // created on match with contexts that events are recorded with.
+   THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
+       cudaSetDevice(device_id), TRITONSERVER_ERROR_INTERNAL,
+       "Failed to set the device"));
+   THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
+       cudaEventCreate(&compute_input_start_event_), TRITONSERVER_ERROR_INTERNAL,
+       "Failed to create cuda event"));
+   THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
+       cudaEventCreate(&compute_infer_start_event_), TRITONSERVER_ERROR_INTERNAL,
+       "Failed to create cuda event"));
+   THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
+       cudaEventCreate(&compute_output_start_event_),
+       TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
#endif
}

@@ -2248,6 +2229,36 @@ ModelInstanceState::GetCudaStreamByInstanceKind()
  return nullptr;
}

+ void
+ ModelInstanceState::SetCurrentCudaStream(
+     const cudaStream_t& stream, const int& device_id)
+ {
+ #ifdef TRITON_ENABLE_GPU
+   at::cuda::CUDAStream torch_stream =
+       at::cuda::getStreamFromExternal(stream, device_id);
+   // This function replaces the default stream with the stream we created. It
+   // is not necessary to change the current device to the desired device when
+   // replacing the default stream for that device. See the documentation here:
+   // https://pytorch.org/cppdocs/api/function_namespacec10_1_1cuda_1a6ed50cc0fc16cc7014d9c2f4c3bd098d.html
+   at::cuda::setCurrentCUDAStream(torch_stream);
+ #endif
+ }
+
+ float
+ ModelInstanceState::GetCudaEventElapsedTime(
+     const cudaEvent_t& start_event, const cudaEvent_t& end_event)
+ {
+   // [FIXME] in the case of cudaEventElapsedTime failure, should handle
+   // stats reporting more gracefully as the durations are inaccurate
+   float duration = 0;
+   LOG_IF_ERROR(
+       ConvertCUDAStatusToTritonError(
+           cudaEventElapsedTime(&duration, start_event, end_event),
+           TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
+       "Failed to capture elapsed time");
+   return duration;
+ }
+
/////////////

extern "C" {
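For context, a self-contained sketch of the event-based timing pattern that CreateCudaEvents() and GetCudaEventElapsedTime() rely on, using the plain CUDA runtime API with hypothetical variable names: events are created on the device that will record them, recorded around work on a stream, and read back in milliseconds once the end event has completed.

// Illustrative only: time a span of stream work with CUDA events.
#include <cuda_runtime.h>
#include <cstdio>

int main()
{
  // Create the events on the device that will record them.
  cudaSetDevice(0);
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaEventRecord(start, stream);
  // ... enqueue kernels / async copies on 'stream' here ...
  cudaEventRecord(stop, stream);

  // Wait until the 'stop' event has actually been reached on the GPU.
  cudaEventSynchronize(stop);

  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);  // elapsed time in milliseconds
  std::printf("elapsed: %f ms\n", ms);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaStreamDestroy(stream);
  return 0;
}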