+#include "ATen/cuda/CUDAEvent.h"
 #include "c10/cuda/CUDAGuard.h"
 #include "c10/cuda/CUDAStream.h"
@@ -70,7 +71,7 @@ bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_
     new_shape_key_ss << "(";
     auto sizes = input.sizes();
     auto rank = input.sizes().size();
-    for (auto i = 0; i < rank; i++) {
+    for (size_t i = 0; i < rank; i++) {
       new_shape_key_ss << sizes[i];
       // For all but the final dimension in the shape key, add comma separator
       if (i < rank - 1) {
@@ -142,13 +143,13 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
       select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible);
   set_rt_device(device);
 
+  compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(device.id);
   // Update active stream based on new device
-  auto current_stream = c10::cuda::getCurrentCUDAStream(device.id);
-  if (current_stream == c10::cuda::getDefaultCUDAStream(device.id)) {
-    compiled_engine->active_stream = c10::cuda::getStreamFromPool(false, device.id);
-    c10::cuda::setCurrentCUDAStream(compiled_engine->active_stream);
+  if (compiled_engine->caller_stream == c10::cuda::getDefaultCUDAStream(device.id)) {
+    compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, device.id);
+    c10::cuda::setCurrentCUDAStream(compiled_engine->engine_stream);
   } else {
-    compiled_engine->active_stream = current_stream;
+    compiled_engine->engine_stream = compiled_engine->caller_stream;
   }
 
   // Target device is new device
@@ -274,16 +275,23 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
 
   if (!CUDAGRAPHS_MODE) {
     // If not in cudagraphs mode, proceed with enqueueV3 as normal
-    compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream);
+    at::cuda::CUDAEvent caller_exec_complete;
+    caller_exec_complete.record(compiled_engine->caller_stream);
+    caller_exec_complete.block(compiled_engine->engine_stream);
+    compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
+    at::cuda::CUDAEvent trt_exec_complete;
+    trt_exec_complete.record(compiled_engine->engine_stream);
+    trt_exec_complete.block(compiled_engine->caller_stream);
   } else if (need_cudagraphs_record) {
     // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
 
     // Cudagraphs cannot record on the current stream, so use an alternate
     c10::cuda::CUDAStream recording_stream = c10::cuda::getStreamFromPool(false, inputs[0].device().index());
     c10::cuda::CUDAStreamGuard guard(recording_stream);
 
-    compiled_engine->exec_ctx->enqueueV3(recording_stream);
-    recording_stream.synchronize();
+    at::cuda::CUDAEvent caller_exec_complete;
+    caller_exec_complete.record(compiled_engine->caller_stream);
+    caller_exec_complete.block(recording_stream);
 
     compiled_engine->cudagraph.capture_begin();
     compiled_engine->exec_ctx->enqueueV3(recording_stream);
@@ -294,7 +302,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
 
   } else {
     // If the cudagraph has already been recorded, copy the input buffers and replay it
-    for (auto i = 0; i < inputs.size(); i++) {
+    for (size_t i = 0; i < inputs.size(); i++) {
       compiled_engine->input_buffers[i].copy_(inputs[i], true);
     }
     compiled_engine->cudagraph.replay();
@@ -305,7 +313,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   // In cudagraphs mode, the output buffers can be reused, so they must
   // be cloned before providing them to the user to avoid data corruption
   if (CUDAGRAPHS_MODE) {
-    for (auto i = 0; i < compiled_engine->output_buffers.size(); i++) {
+    for (size_t i = 0; i < compiled_engine->output_buffers.size(); i++) {
       model_outputs[i] = compiled_engine->output_buffers[i].clone();
     }
   } else {
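
For reference, the diff replaces the single active_stream and the blocking recording_stream.synchronize() with a caller_stream / engine_stream pair ordered through at::cuda::CUDAEvent. Below is a minimal standalone sketch of that record/block handoff; the run_on_side_stream helper and its work callback are hypothetical names introduced only for illustration, while the ATen/c10 stream and event APIs are the ones the patch itself uses.

#include <functional>

#include "ATen/cuda/CUDAEvent.h"
#include "c10/cuda/CUDAStream.h"

// Run `work` on a pooled side stream while keeping it ordered with respect to
// the caller's current stream, without any host-side synchronization.
void run_on_side_stream(c10::DeviceIndex device_idx, const std::function<void(c10::cuda::CUDAStream)>& work) {
  c10::cuda::CUDAStream caller_stream = c10::cuda::getCurrentCUDAStream(device_idx);
  c10::cuda::CUDAStream side_stream = c10::cuda::getStreamFromPool(/*isHighPriority=*/false, device_idx);

  // The side stream must wait for work already queued on the caller's stream
  at::cuda::CUDAEvent caller_ready;
  caller_ready.record(caller_stream);
  caller_ready.block(side_stream);

  // e.g. exec_ctx->enqueueV3(side_stream) in the engine path above
  work(side_stream);

  // The caller's stream must wait for the side-stream work before continuing
  at::cuda::CUDAEvent side_done;
  side_done.record(side_stream);
  side_done.block(caller_stream);
}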