Skip to content

Commit 4778149

Browse files
committed
Capture the timestamp after synchronization
1 parent 1b1e2c6 commit 4778149

File tree

1 file changed

+12
-27
lines changed

1 file changed

+12
-27
lines changed

src/libtorch.cc

Lines changed: 12 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -61,19 +61,6 @@
6161

6262
namespace triton { namespace backend { namespace pytorch {
6363

64-
namespace {
65-
66-
#ifdef TRITON_ENABLE_GPU
67-
void CUDART_CB
68-
CaptureTimestampCallback(void* data)
69-
{
70-
auto* timestamp = reinterpret_cast<std::atomic<uint64_t>*>(data);
71-
SET_TIMESTAMP(*timestamp);
72-
}
73-
#endif
74-
75-
} // namespace
76-
7764
//
7865
// ModelState
7966
//
@@ -1304,16 +1291,22 @@ ModelInstanceState::ProcessRequests(
13041291
}
13051292
}
13061293

1294+
#ifdef TRITON_ENABLE_GPU
1295+
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
1296+
// For 'KIND_MODEL', multiple streams will be involved, so we need to call
1297+
// 'cudaStreamSynchronize' before reading the output tensors.
1298+
for (auto& stream : stream_vec_) {
1299+
cudaStreamSynchronize(stream);
1300+
}
1301+
}
1302+
#endif
1303+
13071304
uint64_t compute_end_ns = 0;
1308-
std::atomic<uint64_t> compute_output_start{0};
1305+
uint64_t compute_output_start = 0;
13091306

13101307
if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) {
13111308
#ifdef TRITON_ENABLE_GPU
1312-
// For the compute infer duration, multiple streams will be involved, so we
1313-
// need to launch a CUDA callback function for timestamp capturing.
1314-
cudaLaunchHostFunc(
1315-
GetCudaStreamByInstanceKind(), CaptureTimestampCallback,
1316-
reinterpret_cast<void*>(&compute_output_start));
1309+
SET_TIMESTAMP(compute_output_start);
13171310
#endif
13181311
} else {
13191312
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
@@ -1323,14 +1316,6 @@ ModelInstanceState::ProcessRequests(
13231316
reinterpret_cast<void*>(&compute_output_start_event_)));
13241317
}
13251318

1326-
#ifdef TRITON_ENABLE_GPU
1327-
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
1328-
for (auto& stream : stream_vec_) {
1329-
cudaStreamSynchronize(stream);
1330-
}
1331-
}
1332-
#endif
1333-
13341319
if (!all_response_failed) {
13351320
if (!invalid_index) {
13361321
RESPOND_ALL_AND_SET_TRUE_IF_ERROR(

0 commit comments

Comments
 (0)