Skip to content

Commit 157920e

Browse files
committed
For 'KIND_MODEL', use cuda events for compute_input_duration and use callback for compute_infer_duration
1 parent b1bd8af commit 157920e

File tree

1 file changed

+116
-105
lines changed

1 file changed

+116
-105
lines changed

src/libtorch.cc

Lines changed: 116 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -64,29 +64,10 @@ namespace triton { namespace backend { namespace pytorch {
6464
namespace {
6565

6666
#ifdef TRITON_ENABLE_GPU
67-
void CUDART_CB
68-
CaptureFirstTimestampCallback(void* data)
69-
{
70-
auto* tuple = reinterpret_cast<std::tuple<uint64_t*, std::mutex*>*>(data);
71-
72-
uint64_t* timestamp = std::get<0>(*tuple);
73-
std::mutex* mu = std::get<1>(*tuple);
74-
75-
std::lock_guard<std::mutex> lock(*mu);
76-
if (*timestamp == 0) {
77-
SET_TIMESTAMP(*timestamp);
78-
}
79-
}
80-
8167
void CUDART_CB
8268
CaptureLastTimestampCallback(void* data)
8369
{
84-
auto* tuple = reinterpret_cast<std::tuple<uint64_t*, std::mutex*>*>(data);
85-
86-
uint64_t* timestamp = std::get<0>(*tuple);
87-
std::mutex* mu = std::get<1>(*tuple);
88-
89-
std::lock_guard<std::mutex> lock(*mu);
70+
auto* timestamp = reinterpret_cast<std::atomic<uint64_t>*>(data);
9071
SET_TIMESTAMP(*timestamp);
9172
}
9273
#endif
@@ -571,6 +552,9 @@ class ModelInstanceState : public BackendModelInstance {
571552
NamingConvention* naming_convention,
572553
const std::vector<std::string>& allowed_io);
573554

555+
// Create CUDA events for statistics collection.
556+
void CreateCudaEvents(const int32_t& device_id);
557+
574558
// Get the appropriate CUDA stream for input and output handling based on the
575559
// instance group type.
576560
cudaStream_t GetCudaStreamByInstanceKind();
@@ -580,6 +564,10 @@ class ModelInstanceState : public BackendModelInstance {
580564
void SetCurrentCudaStream(
581565
const cudaStream_t& stream, const int32_t& device_id);
582566

567+
// Get the elapsed time between two CUDA events.
568+
float GetCudaEventElapsedTime(
569+
const cudaEvent_t& start_event, const cudaEvent_t& end_event);
570+
583571
ModelState* model_state_;
584572

585573
// The full path to the TorchScript model file.
@@ -610,6 +598,9 @@ class ModelInstanceState : public BackendModelInstance {
610598

611599
// Store the cuda streams created for the 'KIND_MODEL' instance group.
612600
std::vector<cudaStream_t> stream_vec_;
601+
602+
// The number of available devices.
603+
int device_cnt_;
613604
};
614605

615606
TRITONSERVER_Error*
@@ -633,47 +624,37 @@ ModelInstanceState::Create(
633624
ModelInstanceState::ModelInstanceState(
634625
ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance)
635626
: BackendModelInstance(model_state, triton_model_instance),
636-
model_state_(model_state), device_(torch::kCPU), is_dict_input_(false)
627+
model_state_(model_state), device_(torch::kCPU), is_dict_input_(false),
628+
device_cnt_(0)
637629
{
638630
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
639631
#ifdef TRITON_ENABLE_GPU
640632
device_ = torch::Device(torch::kCUDA, DeviceId());
641-
// Need to set the CUDA context so that the context that events are
642-
// created on match with contexts that events are recorded with.
643-
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
644-
cudaSetDevice(DeviceId()), TRITONSERVER_ERROR_INTERNAL,
645-
"Failed to set the device"));
646-
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
647-
cudaEventCreate(&compute_input_start_event_),
648-
TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
649-
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
650-
cudaEventCreate(&compute_infer_start_event_),
651-
TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
652-
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
653-
cudaEventCreate(&compute_output_start_event_),
654-
TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
633+
CreateCudaEvents(DeviceId());
655634
#endif
656635
}
657636

637+
#ifdef TRITON_ENABLE_GPU
638+
device_cnt_ = torch::cuda::device_count();
639+
#endif
640+
658641
THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
659642
ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
660643

661644
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
662645
#ifdef TRITON_ENABLE_GPU
663-
// Since we cannot determine the exact devices used by the model, we create
664-
// a CUDA stream for every available device to ensure proper synchronization
665-
// of CUDA streams. This approach may have implications when a timestamp is
666-
// captured on a device that is not used by the model. Currently, this issue
667-
// is addressed by synchronizing the CUDA streams before recording
668-
// timestamps to prevent timestamp skewing. However, in the future, any
669-
// modifications to the CUDA stream synchronization logic should be handled
670-
// with caution.
671-
for (int i = 0; i < torch::cuda::device_count(); i++) {
646+
// Create a CUDA stream for every available device.
647+
for (int i = 0; i < device_cnt_; i++) {
672648
cudaStream_t stream;
673649
THROW_IF_BACKEND_INSTANCE_ERROR(
674650
CreateCudaStream(i, 0 /* cuda_stream_priority */, &stream));
675651
stream_vec_.push_back(stream);
676652
}
653+
if (!stream_vec_.empty()) {
654+
// Create CUDA events on the first device that will be used for collecting
655+
// inputs/outputs.
656+
CreateCudaEvents(0);
657+
}
677658
#endif
678659
}
679660

@@ -733,8 +714,8 @@ void
733714
ModelInstanceState::ClearCache()
734715
{
735716
#ifdef TRITON_ENABLE_GPU
736-
if (device_.is_cuda() || ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) &&
737-
(torch::cuda::device_count() > 0))) {
717+
if (device_.is_cuda() ||
718+
((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) {
738719
c10::cuda::CUDACachingAllocator::emptyCache();
739720
}
740721
#endif // TRITON_ENABLE_GPU
@@ -1237,27 +1218,21 @@ ModelInstanceState::ProcessRequests(
12371218
std::vector<torch::jit::IValue> input_tensors;
12381219
bool cuda_copy = false;
12391220
std::unique_ptr<BackendInputCollector> collector;
1240-
std::mutex timestamp_mu;
12411221

1242-
uint64_t compute_input_start = 0;
1243-
std::tuple<uint64_t*, std::mutex*> compute_input_cb_data(
1244-
&compute_input_start, &timestamp_mu);
1245-
1246-
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
1222+
// For 'KIND_MODEL', it's fine to use CUDA events to calculate the compute
1223+
// input duration since only one stream will be used for input collection.
1224+
// However, for the compute infer duration, multiple streams will be involved,
1225+
// so we need to launch a CUDA callback function for timestamp capturing, as
1226+
// demonstrated in the 'RecordBackendTimestamp' function.
1227+
if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
1228+
((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) {
12471229
#ifdef TRITON_ENABLE_GPU
12481230
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
12491231
responses, request_count, all_response_failed,
12501232
ConvertCUDAStatusToTritonError(
1251-
cudaEventRecord(compute_input_start_event_, stream_),
1233+
cudaEventRecord(
1234+
compute_input_start_event_, GetCudaStreamByInstanceKind()),
12521235
TRITONSERVER_ERROR_INTERNAL, "Failed to record the event."));
1253-
#endif
1254-
} else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
1255-
#ifdef TRITON_ENABLE_GPU
1256-
for (const auto& stream : stream_vec_) {
1257-
cudaLaunchHostFunc(
1258-
stream, CaptureFirstTimestampCallback,
1259-
reinterpret_cast<void*>(&compute_input_cb_data));
1260-
}
12611236
#endif
12621237
}
12631238

@@ -1283,16 +1258,28 @@ ModelInstanceState::ProcessRequests(
12831258

12841259
std::vector<torch::jit::IValue> output_tensors;
12851260
uint64_t compute_start_ns = 0;
1286-
uint64_t compute_infer_start = 0;
1261+
std::atomic<uint64_t> compute_infer_start = 0;
1262+
1263+
// Record 'compute_infer_start_event_' for 'KIND_MODEL' to calculate the
1264+
// compute input duration. The compute infer start timestamp will be recorded
1265+
// in the 'RecordBackendTimestamp' function.
1266+
if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
1267+
#ifdef TRITON_ENABLE_GPU
1268+
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
1269+
responses, request_count, all_response_failed,
1270+
ConvertCUDAStatusToTritonError(
1271+
cudaEventRecord(
1272+
compute_infer_start_event_, GetCudaStreamByInstanceKind()),
1273+
TRITONSERVER_ERROR_INTERNAL, "Failed to record the event."));
1274+
#endif
1275+
}
12871276

1288-
std::tuple<uint64_t*, std::mutex*> compute_infer_cb_data(
1289-
&compute_infer_start, &timestamp_mu);
12901277
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
12911278
responses, request_count, all_response_failed,
12921279
RecordBackendTimestamp(
12931280
&compute_start_ns,
12941281
reinterpret_cast<void*>(&compute_infer_start_event_),
1295-
reinterpret_cast<void*>(&compute_infer_cb_data)));
1282+
reinterpret_cast<void*>(&compute_infer_start)));
12961283

12971284
// Run...
12981285
if (!all_response_failed) {
@@ -1324,16 +1311,14 @@ ModelInstanceState::ProcessRequests(
13241311
}
13251312

13261313
uint64_t compute_end_ns = 0;
1327-
uint64_t compute_output_start = 0;
1328-
std::tuple<uint64_t*, std::mutex*> compute_output_cb_data(
1329-
&compute_output_start, &timestamp_mu);
1314+
std::atomic<uint64_t> compute_output_start = 0;
13301315

13311316
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
13321317
responses, request_count, all_response_failed,
13331318
RecordBackendTimestamp(
13341319
&compute_end_ns,
13351320
reinterpret_cast<void*>(&compute_output_start_event_),
1336-
reinterpret_cast<void*>(&compute_output_cb_data)));
1321+
reinterpret_cast<void*>(&compute_output_start)));
13371322

13381323
#ifdef TRITON_ENABLE_GPU
13391324
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
@@ -1373,35 +1358,25 @@ ModelInstanceState::ProcessRequests(
13731358
// synchronized the stream in the ReadOutputTensors function.
13741359
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
13751360
#ifdef TRITON_ENABLE_GPU
1376-
// [FIXME] in the case of cudaEventElapsedTime failure, should handle
1377-
// stats reporting more gracefully as the durations are inaccurate
1378-
float compute_input_duration = 0;
1379-
float compute_infer_duration = 0;
1380-
LOG_IF_ERROR(
1381-
ConvertCUDAStatusToTritonError(
1382-
cudaEventElapsedTime(
1383-
&compute_input_duration, compute_input_start_event_,
1384-
compute_infer_start_event_),
1385-
TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
1386-
"Failed to capture elapsed time");
1387-
1388-
LOG_IF_ERROR(
1389-
ConvertCUDAStatusToTritonError(
1390-
cudaEventElapsedTime(
1391-
&compute_infer_duration, compute_infer_start_event_,
1392-
compute_output_start_event_),
1393-
TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
1394-
"Failed to capture elapsed time");
1361+
float compute_input_duration = GetCudaEventElapsedTime(
1362+
compute_input_start_event_, compute_infer_start_event_);
1363+
float compute_infer_duration = GetCudaEventElapsedTime(
1364+
compute_infer_start_event_, compute_output_start_event_);
13951365

13961366
compute_start_ns = exec_start_ns + (compute_input_duration * 1e6);
13971367
compute_end_ns = compute_start_ns + (compute_infer_duration * 1e6);
13981368
#endif
1399-
} else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
1400-
uint64_t compute_input_duration = compute_infer_start - compute_input_start;
1369+
} else if (
1370+
(Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
1371+
#ifdef TRITON_ENABLE_GPU
1372+
float compute_input_duration = GetCudaEventElapsedTime(
1373+
compute_input_start_event_, compute_infer_start_event_);
14011374
uint64_t compute_infer_duration =
14021375
compute_output_start - compute_infer_start;
1403-
compute_start_ns = exec_start_ns + compute_input_duration;
1376+
1377+
compute_start_ns = exec_start_ns + (compute_input_duration * 1e6);
14041378
compute_end_ns = compute_start_ns + compute_infer_duration;
1379+
#endif
14051380
}
14061381

14071382
// Report statistics for each request.
@@ -1473,7 +1448,7 @@ ModelInstanceState::Execute(
14731448
bool is_device_gpu =
14741449
(device_.is_cuda() ||
14751450
((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) &&
1476-
(torch::cuda::device_count() > 0)));
1451+
(device_cnt_ > 0)));
14771452
if (std::get<1>(model_state_->EnabledNvfuserPair()) && is_device_gpu) {
14781453
torch::jit::overrideCanFuseOnCPU(false);
14791454
torch::jit::overrideCanFuseOnGPU(false);
@@ -2195,17 +2170,17 @@ ModelInstanceState::RecordBackendTimestamp(
21952170
uint64_t* timestamp, void* cuda_event, void* timestamp_cb_data)
21962171
{
21972172
// For the 'KIND_GPU' instance group, we use a CUDA event to record the
2198-
// timestamp. For the 'KIND_MODEL' instance group, it is complicated to
2199-
// calculate the elapsed time between two cuda events from different devices,
2200-
// so we launch a CUDA callback function to record the timestamp.
2173+
// timestamp. For the 'KIND_MODEL' instance group, launch a CUDA callback
2174+
// function to record the timestamp for multiple streams.
22012175
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
22022176
#ifdef TRITON_ENABLE_GPU
22032177
cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
22042178
RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(
22052179
cudaEventRecord(*lcuda_event, stream_), TRITONSERVER_ERROR_INTERNAL,
22062180
"Failed to record the event."));
22072181
#endif
2208-
} else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
2182+
} else if (
2183+
(Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
22092184
#ifdef TRITON_ENABLE_GPU
22102185
for (const auto& stream : stream_vec_) {
22112186
cudaLaunchHostFunc(
@@ -2219,17 +2194,23 @@ ModelInstanceState::RecordBackendTimestamp(
22192194
}
22202195

22212196
void
2222-
ModelInstanceState::SetCurrentCudaStream(
2223-
const cudaStream_t& stream, const int& device_id)
2197+
ModelInstanceState::CreateCudaEvents(const int32_t& device_id)
22242198
{
22252199
#ifdef TRITON_ENABLE_GPU
2226-
at::cuda::CUDAStream torch_stream =
2227-
at::cuda::getStreamFromExternal(stream, device_id);
2228-
// This function replaces the default stream with the stream we created. It
2229-
// is not necessary to change the current device to the desired device when
2230-
// replacing the default stream for that device. See the documentation here:
2231-
// https://pytorch.org/cppdocs/api/function_namespacec10_1_1cuda_1a6ed50cc0fc16cc7014d9c2f4c3bd098d.html
2232-
at::cuda::setCurrentCUDAStream(torch_stream);
2200+
// Need to set the CUDA context so that the context that events are
2201+
// created on match with contexts that events are recorded with.
2202+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
2203+
cudaSetDevice(device_id), TRITONSERVER_ERROR_INTERNAL,
2204+
"Failed to set the device"));
2205+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
2206+
cudaEventCreate(&compute_input_start_event_), TRITONSERVER_ERROR_INTERNAL,
2207+
"Failed to create cuda event"));
2208+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
2209+
cudaEventCreate(&compute_infer_start_event_), TRITONSERVER_ERROR_INTERNAL,
2210+
"Failed to create cuda event"));
2211+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
2212+
cudaEventCreate(&compute_output_start_event_),
2213+
TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
22332214
#endif
22342215
}
22352216

@@ -2248,6 +2229,36 @@ ModelInstanceState::GetCudaStreamByInstanceKind()
22482229
return nullptr;
22492230
}
22502231

2232+
void
2233+
ModelInstanceState::SetCurrentCudaStream(
2234+
const cudaStream_t& stream, const int& device_id)
2235+
{
2236+
#ifdef TRITON_ENABLE_GPU
2237+
at::cuda::CUDAStream torch_stream =
2238+
at::cuda::getStreamFromExternal(stream, device_id);
2239+
// This function replaces the default stream with the stream we created. It
2240+
// is not necessary to change the current device to the desired device when
2241+
// replacing the default stream for that device. See the documentation here:
2242+
// https://pytorch.org/cppdocs/api/function_namespacec10_1_1cuda_1a6ed50cc0fc16cc7014d9c2f4c3bd098d.html
2243+
at::cuda::setCurrentCUDAStream(torch_stream);
2244+
#endif
2245+
}
2246+
2247+
float
2248+
ModelInstanceState::GetCudaEventElapsedTime(
2249+
const cudaEvent_t& start_event, const cudaEvent_t& end_event)
2250+
{
2251+
// [FIXME] in the case of cudaEventElapsedTime failure, should handle
2252+
// stats reporting more gracefully as the durations are inaccurate
2253+
float duration = 0;
2254+
LOG_IF_ERROR(
2255+
ConvertCUDAStatusToTritonError(
2256+
cudaEventElapsedTime(&duration, start_event, end_event),
2257+
TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
2258+
"Failed to capture elapsed time");
2259+
return duration;
2260+
}
2261+
22512262
/////////////
22522263

22532264
extern "C" {

0 commit comments

Comments
 (0)