
Commit 6ab2ba7

Move the cudaLaunchHostFunc from RecordBackendTimestamp function
1 parent 6e7a066 commit 6ab2ba7

File tree

1 file changed (+30, -42 lines)


src/libtorch.cc

Lines changed: 30 additions & 42 deletions
```diff
@@ -545,7 +545,7 @@ class ModelInstanceState : public BackendModelInstance {
       TRITONBACKEND_Request** requests, const uint32_t request_count,
       std::vector<TRITONBACKEND_Response*>* responses);
   TRITONSERVER_Error* RecordBackendTimestamp(
-      uint64_t* timestamp, void* cuda_event, void* timestamp_cb_data);
+      uint64_t* timestamp, void* cuda_event);
 
   // Get the naming convention for inputs/outputs from the model configuration
   TRITONSERVER_Error* GetNamingConvention(
```
```diff
@@ -1228,9 +1228,6 @@ ModelInstanceState::ProcessRequests(
 
   // For 'KIND_MODEL', it's fine to use CUDA events to calculate the compute
   // input duration since only one stream will be used for input collection.
-  // However, for the compute infer duration, multiple streams will be involved,
-  // so we need to launch a CUDA callback function for timestamp capturing, as
-  // demonstrated in the 'RecordBackendTimestamp' function.
   if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
       ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) {
 #ifdef TRITON_ENABLE_GPU
```
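The comment retained above captures the rationale: with a single stream for input collection, plain CUDA event timing is sufficient. A minimal self-contained sketch of that single-stream pattern (the stream, events, and placeholder work are illustrative, not the backend's members):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Single-stream event timing, as used for the compute input duration:
// both events are recorded on the same stream, so the elapsed time
// between them covers exactly the work enqueued in between.
int main() {
  cudaStream_t stream;
  cudaEvent_t start, stop;
  cudaStreamCreate(&stream);
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start, stream);
  // ... enqueue input-collection work on `stream` here ...
  cudaEventRecord(stop, stream);

  cudaEventSynchronize(stop);  // wait until `stop` has actually occurred
  float elapsed_ms = 0;
  cudaEventElapsedTime(&elapsed_ms, start, stop);
  printf("elapsed: %f ms\n", elapsed_ms);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaStreamDestroy(stream);
  return 0;
}
```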
```diff
@@ -1265,28 +1262,18 @@ ModelInstanceState::ProcessRequests(
 
   std::vector<torch::jit::IValue> output_tensors;
   uint64_t compute_start_ns = 0;
-  std::atomic<uint64_t> compute_infer_start = 0;
-
-  // Record 'compute_infer_start_event_' for 'KIND_MODEL' to calculate the
-  // compute input duration. The compute infer start timestamp will be recorded
-  // in the 'RecordBackendTimestamp' function.
-  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
-#ifdef TRITON_ENABLE_GPU
-    RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
-        responses, request_count, all_response_failed,
-        ConvertCUDAStatusToTritonError(
-            cudaEventRecord(
-                compute_infer_start_event_, GetCudaStreamByInstanceKind()),
-            TRITONSERVER_ERROR_INTERNAL, "Failed to record the event."));
-#endif
-  }
+  uint64_t compute_infer_start = 0;
 
   RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
       responses, request_count, all_response_failed,
       RecordBackendTimestamp(
           &compute_start_ns,
-          reinterpret_cast<void*>(&compute_infer_start_event_),
-          reinterpret_cast<void*>(&compute_infer_start)));
+          reinterpret_cast<void*>(&compute_infer_start_event_)));
+
+  // For 'KIND_MODEL', capture the timestamp for the compute infer duration.
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
+    SET_TIMESTAMP(compute_infer_start);
+  }
 
   // Run...
   if (!all_response_failed) {
```
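The replacement drops the event-plus-callback machinery for the infer-start point and instead takes a host-side reading with `SET_TIMESTAMP` right after the start event is recorded, so `compute_infer_start` no longer needs to be atomic. `SET_TIMESTAMP` comes from Triton's backend common headers; a sketch approximating its effect (the macro body here is an assumption inferred from the nanosecond `uint64_t` counters it fills, not the verbatim definition):

```cpp
#include <chrono>
#include <cstdint>

// Hedged approximation of Triton's SET_TIMESTAMP macro: capture a
// monotonic host timestamp in nanoseconds. The real definition lives in
// the Triton backend common headers; this only mirrors its effect.
#define SET_TIMESTAMP(TS_NS)                                           \
  {                                                                    \
    TS_NS = std::chrono::duration_cast<std::chrono::nanoseconds>(      \
                std::chrono::steady_clock::now().time_since_epoch())   \
                .count();                                              \
  }

int main() {
  uint64_t compute_infer_start = 0;
  SET_TIMESTAMP(compute_infer_start);  // host-side capture, no CUDA sync
  return 0;
}
```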
```diff
@@ -1320,12 +1307,21 @@ ModelInstanceState::ProcessRequests(
   uint64_t compute_end_ns = 0;
   std::atomic<uint64_t> compute_output_start = 0;
 
-  RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
-      responses, request_count, all_response_failed,
-      RecordBackendTimestamp(
-          &compute_end_ns,
-          reinterpret_cast<void*>(&compute_output_start_event_),
-          reinterpret_cast<void*>(&compute_output_start)));
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
+#ifdef TRITON_ENABLE_GPU
+    // For the compute infer duration, multiple streams will be involved, so we
+    // need to launch a CUDA callback function for timestamp capturing.
+    cudaLaunchHostFunc(
+        GetCudaStreamByInstanceKind(), CaptureLastTimestampCallback,
+        reinterpret_cast<void*>(&compute_output_start));
+#endif
+  } else {
+    RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
+        responses, request_count, all_response_failed,
+        RecordBackendTimestamp(
+            &compute_end_ns,
+            reinterpret_cast<void*>(&compute_output_start_event_)));
+  }
 
 #ifdef TRITON_ENABLE_GPU
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
```
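`cudaLaunchHostFunc` enqueues a host function that the CUDA runtime invokes once all work previously submitted to the stream has completed, which is why it can stand in for an event when multiple streams feed one measured duration. `CaptureLastTimestampCallback` itself is not shown in this diff; a plausible sketch, assuming it stores the latest steady-clock reading into the `std::atomic<uint64_t>` passed as user data (the body is an assumption; only the launch call appears in the commit):

```cpp
#include <cuda_runtime.h>
#include <atomic>
#include <chrono>
#include <cstdint>

// Plausible sketch of a host-func callback in the spirit of
// CaptureLastTimestampCallback: it runs after all prior work on the
// stream drains and keeps the later of (stored, now), so that when it
// is launched on several streams the final value reflects the last
// stream to finish.
static void CUDART_CB CaptureLastTimestampCallback(void* data) {
  auto* last_ns = reinterpret_cast<std::atomic<uint64_t>*>(data);
  uint64_t now = std::chrono::duration_cast<std::chrono::nanoseconds>(
                     std::chrono::steady_clock::now().time_since_epoch())
                     .count();
  // Atomic max-update: retry if another stream's callback races us.
  uint64_t prev = last_ns->load();
  while (prev < now && !last_ns->compare_exchange_weak(prev, now)) {
  }
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  std::atomic<uint64_t> compute_output_start{0};

  // ... enqueue inference work on `stream` here ...
  cudaLaunchHostFunc(
      stream, CaptureLastTimestampCallback,
      reinterpret_cast<void*>(&compute_output_start));

  cudaStreamSynchronize(stream);  // the callback has run once this returns
  cudaStreamDestroy(stream);
  return 0;
}
```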
```diff
@@ -2174,25 +2170,15 @@ ModelInstanceState::ReadOutputTensors(
 
 TRITONSERVER_Error*
 ModelInstanceState::RecordBackendTimestamp(
-    uint64_t* timestamp, void* cuda_event, void* timestamp_cb_data)
+    uint64_t* timestamp, void* cuda_event)
 {
-  // For the 'KIND_GPU' instance group, we use a CUDA event to record the
-  // timestamp. For the 'KIND_MODEL' instance group, launch a CUDA callback
-  // function to record the timestamp for multiple streams.
-  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+  if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+      ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) {
 #ifdef TRITON_ENABLE_GPU
     cudaEvent_t* lcuda_event = reinterpret_cast<cudaEvent_t*>(cuda_event);
     RETURN_IF_ERROR(ConvertCUDAStatusToTritonError(
-        cudaEventRecord(*lcuda_event, stream_), TRITONSERVER_ERROR_INTERNAL,
-        "Failed to record the event."));
-#endif
-  } else if (
-      (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
-#ifdef TRITON_ENABLE_GPU
-    for (const auto& stream : stream_vec_) {
-      cudaLaunchHostFunc(
-          stream, CaptureLastTimestampCallback, timestamp_cb_data);
-    }
+        cudaEventRecord(*lcuda_event, GetCudaStreamByInstanceKind()),
+        TRITONSERVER_ERROR_INTERNAL, "Failed to record the event."));
 #endif
   } else {
     SET_TIMESTAMP(*timestamp);
```
```diff
@@ -2255,6 +2241,7 @@ float
 ModelInstanceState::GetCudaEventElapsedTime(
     const cudaEvent_t& start_event, const cudaEvent_t& end_event)
 {
+#ifdef TRITON_ENABLE_GPU
   // [FIXME] in the case of cudaEventElapsedTime failure, should handle
   // stats reporting more gracefully as the durations are inaccurate
   float duration = 0;
```
```diff
@@ -2264,6 +2251,7 @@ ModelInstanceState::GetCudaEventElapsedTime(
           TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"),
       "Failed to capture elapsed time");
   return duration;
+#endif
 }
 
 /////////////
```
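`cudaEventElapsedTime` reports the delta as a `float` in milliseconds, while the surrounding counters (`compute_start_ns`, `compute_end_ns`) are nanosecond integers, so callers of `GetCudaEventElapsedTime` have to convert. A hedged helper illustrating that conversion (`ElapsedNs` is not part of the backend; error handling is reduced to a sentinel for brevity):

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// Convert a CUDA event delta (float milliseconds) to the uint64_t
// nanosecond form used by the backend's *_ns timestamps.
uint64_t ElapsedNs(cudaEvent_t start, cudaEvent_t stop) {
  float ms = 0;
  if (cudaEventElapsedTime(&ms, start, stop) != cudaSuccess) {
    return 0;  // real code should report the error instead
  }
  return static_cast<uint64_t>(ms * 1000.0f * 1000.0f);
}
```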
