Skip to content

Commit 6e7a066

Browse files
committed
For 'KIND_MODEL', use cuda events for compute_input_duration and use callback for compute_infer_duration
1 parent b1bd8af commit 6e7a066

File tree

1 file changed

+117
-99
lines changed

1 file changed

+117
-99
lines changed

src/libtorch.cc

Lines changed: 117 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -64,29 +64,10 @@ namespace triton { namespace backend { namespace pytorch {
6464
namespace {
6565

6666
#ifdef TRITON_ENABLE_GPU
67-
void CUDART_CB
68-
CaptureFirstTimestampCallback(void* data)
69-
{
70-
auto* tuple = reinterpret_cast<std::tuple<uint64_t*, std::mutex*>*>(data);
71-
72-
uint64_t* timestamp = std::get<0>(*tuple);
73-
std::mutex* mu = std::get<1>(*tuple);
74-
75-
std::lock_guard<std::mutex> lock(*mu);
76-
if (*timestamp == 0) {
77-
SET_TIMESTAMP(*timestamp);
78-
}
79-
}
80-
8167
void CUDART_CB
8268
CaptureLastTimestampCallback(void* data)
8369
{
84-
auto* tuple = reinterpret_cast<std::tuple<uint64_t*, std::mutex*>*>(data);
85-
86-
uint64_t* timestamp = std::get<0>(*tuple);
87-
std::mutex* mu = std::get<1>(*tuple);
88-
89-
std::lock_guard<std::mutex> lock(*mu);
70+
auto* timestamp = reinterpret_cast<std::atomic<uint64_t>*>(data);
9071
SET_TIMESTAMP(*timestamp);
9172
}
9273
#endif
@@ -571,6 +552,9 @@ class ModelInstanceState : public BackendModelInstance {
571552
NamingConvention* naming_convention,
572553
const std::vector<std::string>& allowed_io);
573554

555+
// Create CUDA events for statistics collection.
556+
void CreateCudaEvents(const int32_t& device_id);
557+
574558
// Get the appropriate CUDA stream for input and output handling based on the
575559
// instance group type.
576560
cudaStream_t GetCudaStreamByInstanceKind();
@@ -580,6 +564,10 @@ class ModelInstanceState : public BackendModelInstance {
580564
void SetCurrentCudaStream(
581565
const cudaStream_t& stream, const int32_t& device_id);
582566

567+
// Get the elapsed time between two CUDA events.
568+
float GetCudaEventElapsedTime(
569+
const cudaEvent_t& start_event, const cudaEvent_t& end_event);
570+
583571
ModelState* model_state_;
584572

585573
// The full path to the TorchScript model file.
@@ -610,6 +598,9 @@ class ModelInstanceState : public BackendModelInstance {
610598

611599
// Store the cuda streams created for the 'KIND_MODEL' instance group.
612600
std::vector<cudaStream_t> stream_vec_;
601+
602+
// The number of available devices.
603+
int device_cnt_;
613604
};
614605

615606
TRITONSERVER_Error*
@@ -633,28 +624,20 @@ ModelInstanceState::Create(
633624
ModelInstanceState::ModelInstanceState(
634625
ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance)
635626
: BackendModelInstance(model_state, triton_model_instance),
636-
model_state_(model_state), device_(torch::kCPU), is_dict_input_(false)
627+
model_state_(model_state), device_(torch::kCPU), is_dict_input_(false),
628+
device_cnt_(0)
637629
{
638630
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
639631
#ifdef TRITON_ENABLE_GPU
640632
device_ = torch::Device(torch::kCUDA, DeviceId());
641-
// Need to set the CUDA context so that the context that events are
642-
// created on match with contexts that events are recorded with.
643-
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
644-
cudaSetDevice(DeviceId()), TRITONSERVER_ERROR_INTERNAL,
645-
"Failed to set the device"));
646-
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
647-
cudaEventCreate(&compute_input_start_event_),
648-
TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
649-
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
650-
cudaEventCreate(&compute_infer_start_event_),
651-
TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
652-
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
653-
cudaEventCreate(&compute_output_start_event_),
654-
TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
633+
CreateCudaEvents(DeviceId());
655634
#endif
656635
}
657636

637+
#ifdef TRITON_ENABLE_GPU
638+
device_cnt_ = torch::cuda::device_count();
639+
#endif
640+
658641
THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
659642
ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
660643

@@ -668,12 +651,17 @@ ModelInstanceState::ModelInstanceState(
668651
// timestamps to prevent timestamp skewing. However, in the future, any
669652
// modifications to the CUDA stream synchronization logic should be handled
670653
// with caution.
671-
for (int i = 0; i < torch::cuda::device_count(); i++) {
654+
for (int i = 0; i < device_cnt_; i++) {
672655
cudaStream_t stream;
673656
THROW_IF_BACKEND_INSTANCE_ERROR(
674657
CreateCudaStream(i, 0 /* cuda_stream_priority */, &stream));
675658
stream_vec_.push_back(stream);
676659
}
660+
if (!stream_vec_.empty()) {
661+
// Create CUDA events on the first device that will be used for collecting
662+
// inputs/outputs.
663+
CreateCudaEvents(0);
664+
}
677665
#endif
678666
}
679667

@@ -733,8 +721,8 @@ void
733721
ModelInstanceState::ClearCache()
734722
{
735723
#ifdef TRITON_ENABLE_GPU
736-
if (device_.is_cuda() || ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) &&
737-
(torch::cuda::device_count() > 0))) {
724+
if (device_.is_cuda() ||
725+
((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) {
738726
c10::cuda::CUDACachingAllocator::emptyCache();
739727
}
740728
#endif // TRITON_ENABLE_GPU
@@ -1237,27 +1225,21 @@ ModelInstanceState::ProcessRequests(
12371225
std::vector<torch::jit::IValue> input_tensors;
12381226
bool cuda_copy = false;
12391227
std::unique_ptr<BackendInputCollector> collector;
1240-
std::mutex timestamp_mu;
12411228

1242-
uint64_t compute_input_start = 0;
1243-
std::tuple<uint64_t*, std::mutex*> compute_input_cb_data(
1244-
&compute_input_start, &timestamp_mu);
1245-
1246-
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
1229+
// For 'KIND_MODEL', it's fine to use CUDA events to calculate the compute
1230+
// input duration since only one stream will be used for input collection.
1231+
// However, for the compute infer duration, multiple streams will be involved,
1232+
// so we need to launch a CUDA callback function for timestamp capturing, as
1233+
// demonstrated in the 'RecordBackendTimestamp' function.
1234+
if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
1235+
((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) {
12471236
#ifdef TRITON_ENABLE_GPU
12481237
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
12491238
responses, request_count, all_response_failed,
12501239
ConvertCUDAStatusToTritonError(
1251-
cudaEventRecord(compute_input_start_event_, stream_),
1240+
cudaEventRecord(
1241+
compute_input_start_event_, GetCudaStreamByInstanceKind()),
12521242
TRITONSERVER_ERROR_INTERNAL, "Failed to record the event."));
1253-
#endif
1254-
} else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
1255-
#ifdef TRITON_ENABLE_GPU
1256-
for (const auto& stream : stream_vec_) {
1257-
cudaLaunchHostFunc(
1258-
stream, CaptureFirstTimestampCallback,
1259-
reinterpret_cast<void*>(&compute_input_cb_data));
1260-
}
12611243
#endif
12621244
}
12631245

@@ -1283,16 +1265,28 @@ ModelInstanceState::ProcessRequests(
12831265

12841266
std::vector<torch::jit::IValue> output_tensors;
12851267
uint64_t compute_start_ns = 0;
1286-
uint64_t compute_infer_start = 0;
1268+
std::atomic<uint64_t> compute_infer_start = 0;
1269+
1270+
// Record 'compute_infer_start_event_' for 'KIND_MODEL' to calculate the
1271+
// compute input duration. The compute infer start timestamp will be recorded
1272+
// in the 'RecordBackendTimestamp' function.
1273+
if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
1274+
#ifdef TRITON_ENABLE_GPU
1275+
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
1276+
responses, request_count, all_response_failed,
1277+
ConvertCUDAStatusToTritonError(
1278+
cudaEventRecord(
1279+
compute_infer_start_event_, GetCudaStreamByInstanceKind()),
1280+
TRITONSERVER_ERROR_INTERNAL, "Failed to record the event."));
1281+
#endif
1282+
}
12871283

1288-
std::tuple<uint64_t*, std::mutex*> compute_infer_cb_data(
1289-
&compute_infer_start, &timestamp_mu);
12901284
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
12911285
responses, request_count, all_response_failed,
12921286
RecordBackendTimestamp(
12931287
&compute_start_ns,
12941288
reinterpret_cast<void*>(&compute_infer_start_event_),
1295-
reinterpret_cast<void*>(&compute_infer_cb_data)));
1289+
reinterpret_cast<void*>(&compute_infer_start)));
12961290

12971291
// Run...
12981292
if (!all_response_failed) {
@@ -1324,16 +1318,14 @@ ModelInstanceState::ProcessRequests(
13241318
}
13251319

13261320
uint64_t compute_end_ns = 0;
1327-
uint64_t compute_output_start = 0;
1328-
std::tuple<uint64_t*, std::mutex*> compute_output_cb_data(
1329-
&compute_output_start, &timestamp_mu);
1321+
std::atomic<uint64_t> compute_output_start = 0;
13301322

13311323
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
13321324
responses, request_count, all_response_failed,
13331325
RecordBackendTimestamp(
13341326
&compute_end_ns,
13351327
reinterpret_cast<void*>(&compute_output_start_event_),
1336-
reinterpret_cast<void*>(&compute_output_cb_data)));
1328+
reinterpret_cast<void*>(&compute_output_start)));
13371329

13381330
#ifdef TRITON_ENABLE_GPU
13391331
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
@@ -1373,35 +1365,25 @@ ModelInstanceState::ProcessRequests(
13731365
// synchronized the stream in the ReadOutputTensors function.
13741366
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
13751367
#ifdef TRITON_ENABLE_GPU
1376-
// [FIXME] in the case of cudaEventElapsedTime failure, should handle
1377-
// stats reporting more gracefully as the durations are inaccurate
1378-
float compute_input_duration = 0;
1379-
float compute_infer_duration = 0;
1380-
LOG_IF_ERROR(
1381-
ConvertCUDAStatusToTritonError(
1382-
cudaEventElapsedTime(
1383-
&compute_input_duration, compute_input_start_event_,
1384-
compute_infer_start_event_),
1385-
TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
1386-
"Failed to capture elapsed time");
1387-
1388-
LOG_IF_ERROR(
1389-
ConvertCUDAStatusToTritonError(
1390-
cudaEventElapsedTime(
1391-
&compute_infer_duration, compute_infer_start_event_,
1392-
compute_output_start_event_),
1393-
TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
1394-
"Failed to capture elapsed time");
1368+
float compute_input_duration = GetCudaEventElapsedTime(
1369+
compute_input_start_event_, compute_infer_start_event_);
1370+
float compute_infer_duration = GetCudaEventElapsedTime(
1371+
compute_infer_start_event_, compute_output_start_event_);
13951372

13961373
compute_start_ns = exec_start_ns + (compute_input_duration * 1e6);
13971374
compute_end_ns = compute_start_ns + (compute_infer_duration * 1e6);
13981375
#endif
1399-
} else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
1400-
uint64_t compute_input_duration = compute_infer_start - compute_input_start;
1376+
} else if (
1377+
(Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
1378+
#ifdef TRITON_ENABLE_GPU
1379+
float compute_input_duration = GetCudaEventElapsedTime(
1380+
compute_input_start_event_, compute_infer_start_event_);
14011381
uint64_t compute_infer_duration =
14021382
compute_output_start - compute_infer_start;
1403-
compute_start_ns = exec_start_ns + compute_input_duration;
1383+
1384+
compute_start_ns = exec_start_ns + (compute_input_duration * 1e6);
14041385
compute_end_ns = compute_start_ns + compute_infer_duration;
1386+
#endif
14051387
}
14061388

14071389
// Report statistics for each request.
@@ -1473,7 +1455,7 @@ ModelInstanceState::Execute(
14731455
bool is_device_gpu =
14741456
(device_.is_cuda() ||
14751457
((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) &&
1476-
(torch::cuda::device_count() > 0)));
1458+
(device_cnt_ > 0)));
14771459
if (std::get<1>(model_state_->EnabledNvfuserPair()) && is_device_gpu) {
14781460
torch::jit::overrideCanFuseOnCPU(false);
14791461
torch::jit::overrideCanFuseOnGPU(false);
@@ -2030,8 +2012,8 @@ ModelInstanceState::SetInputTensors(
20302012
ConvertDataTypeToTorchType(batch_input.DataType());
20312013
torch::TensorOptions options{torch_dtype.second};
20322014
auto updated_options = (dst_memory_type == TRITONSERVER_MEMORY_GPU)
2033-
? options.device(torch::kCUDA, device_.index())
2034-
: options.device(torch::kCPU);
2015+
? options.device(torch::kCUDA, device_.index())
2016+
: options.device(torch::kCPU);
20352017

20362018
torch::Tensor input_tensor = torch::from_blob(
20372019
const_cast<char*>(dst_buffer), shape, updated_options);
@@ -2195,17 +2177,17 @@ ModelInstanceState::RecordBackendTimestamp(
21952177
uint64_t* timestamp, void* cuda_event, void* timestamp_cb_data)
21962178
{
21972179
// For the 'KIND_GPU' instance group, we use a CUDA event to record the
2198-
// timestamp. For the 'KIND_MODEL' instance group, it is complicated to
2199-
// calculate the elapsed time between two cuda events from different devices,
2200-
// so we launch a CUDA callback function to record the timestamp.
2180+
// timestamp. For the 'KIND_MODEL' instance group, launch a CUDA callback
2181+
// function to record the timestamp for multiple streams.
22012182
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
22022183
#ifdef TRITON_ENABLE_GPU
22032184
cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
22042185
RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(
22052186
cudaEventRecord(*lcuda_event, stream_), TRITONSERVER_ERROR_INTERNAL,
22062187
"Failed to record the event."));
22072188
#endif
2208-
} else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
2189+
} else if (
2190+
(Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
22092191
#ifdef TRITON_ENABLE_GPU
22102192
for (const auto& stream : stream_vec_) {
22112193
cudaLaunchHostFunc(
@@ -2219,17 +2201,23 @@ ModelInstanceState::RecordBackendTimestamp(
22192201
}
22202202

22212203
void
2222-
ModelInstanceState::SetCurrentCudaStream(
2223-
const cudaStream_t& stream, const int& device_id)
2204+
ModelInstanceState::CreateCudaEvents(const int32_t& device_id)
22242205
{
22252206
#ifdef TRITON_ENABLE_GPU
2226-
at::cuda::CUDAStream torch_stream =
2227-
at::cuda::getStreamFromExternal(stream, device_id);
2228-
// This function replaces the default stream with the stream we created. It
2229-
// is not necessary to change the current device to the desired device when
2230-
// replacing the default stream for that device. See the documentation here:
2231-
// https://pytorch.org/cppdocs/api/function_namespacec10_1_1cuda_1a6ed50cc0fc16cc7014d9c2f4c3bd098d.html
2232-
at::cuda::setCurrentCUDAStream(torch_stream);
2207+
// Need to set the CUDA context so that the context that events are
2208+
// created on match with contexts that events are recorded with.
2209+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
2210+
cudaSetDevice(device_id), TRITONSERVER_ERROR_INTERNAL,
2211+
"Failed to set the device"));
2212+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
2213+
cudaEventCreate(&compute_input_start_event_), TRITONSERVER_ERROR_INTERNAL,
2214+
"Failed to create cuda event"));
2215+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
2216+
cudaEventCreate(&compute_infer_start_event_), TRITONSERVER_ERROR_INTERNAL,
2217+
"Failed to create cuda event"));
2218+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
2219+
cudaEventCreate(&compute_output_start_event_),
2220+
TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
22332221
#endif
22342222
}
22352223

@@ -2248,6 +2236,36 @@ ModelInstanceState::GetCudaStreamByInstanceKind()
22482236
return nullptr;
22492237
}
22502238

2239+
void
2240+
ModelInstanceState::SetCurrentCudaStream(
2241+
const cudaStream_t& stream, const int& device_id)
2242+
{
2243+
#ifdef TRITON_ENABLE_GPU
2244+
at::cuda::CUDAStream torch_stream =
2245+
at::cuda::getStreamFromExternal(stream, device_id);
2246+
// This function replaces the default stream with the stream we created. It
2247+
// is not necessary to change the current device to the desired device when
2248+
// replacing the default stream for that device. See the documentation here:
2249+
// https://pytorch.org/cppdocs/api/function_namespacec10_1_1cuda_1a6ed50cc0fc16cc7014d9c2f4c3bd098d.html
2250+
at::cuda::setCurrentCUDAStream(torch_stream);
2251+
#endif
2252+
}
2253+
2254+
float
2255+
ModelInstanceState::GetCudaEventElapsedTime(
2256+
const cudaEvent_t& start_event, const cudaEvent_t& end_event)
2257+
{
2258+
// [FIXME] in the case of cudaEventElapsedTime failure, should handle
2259+
// stats reporting more gracefully as the durations are inaccurate
2260+
float duration = 0;
2261+
LOG_IF_ERROR(
2262+
ConvertCUDAStatusToTritonError(
2263+
cudaEventElapsedTime(&duration, start_event, end_event),
2264+
TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
2265+
"Failed to capture elapsed time");
2266+
return duration;
2267+
}
2268+
22512269
/////////////
22522270

22532271
extern "C" {

0 commit comments

Comments
 (0)