Skip to content

Commit 1eebc04

Browse files
committed
fix(//tests/cpp): Fix the BERT C++ test
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 52f10cf commit 1eebc04

File tree

5 files changed

+28
-23
lines changed

5 files changed

+28
-23
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,6 @@ TRTEngine::TRTEngine(
7272
set_rt_device(device_info);
7373

7474
// Set active stream to non-default stream
75-
auto current_stream = c10::cuda::getCurrentCUDAStream(device_info.id);
76-
if (current_stream == c10::cuda::getDefaultCUDAStream(device_info.id)) {
77-
active_stream = c10::cuda::getStreamFromPool(false, device_info.id);
78-
c10::cuda::setCurrentCUDAStream(active_stream);
79-
} else {
80-
active_stream = current_stream;
81-
}
82-
8375
rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));
8476

8577
name = slugify(mod_name);

core/runtime/TRTEngine.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ struct TRTEngine : torch::CustomClassHolder {
7070

7171
// CUDAGraph-Related Functionality
7272
at::cuda::CUDAGraph cudagraph = {};
73-
at::cuda::CUDAStream active_stream = c10::cuda::getDefaultCUDAStream();
73+
at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
74+
at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
7475
std::vector<at::Tensor> input_buffers = {};
7576
std::vector<at::Tensor> output_buffers = {};
7677
std::string shape_key;

core/runtime/execute_engine.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "ATen/cuda/CUDAEvent.h"
12
#include "c10/cuda/CUDAGuard.h"
23
#include "c10/cuda/CUDAStream.h"
34

@@ -70,7 +71,7 @@ bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_
7071
new_shape_key_ss << "(";
7172
auto sizes = input.sizes();
7273
auto rank = input.sizes().size();
73-
for (auto i = 0; i < rank; i++) {
74+
for (size_t i = 0; i < rank; i++) {
7475
new_shape_key_ss << sizes[i];
7576
// For all but the final dimension in the shape key, add comma separator
7677
if (i < rank - 1) {
@@ -142,13 +143,13 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
142143
select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible);
143144
set_rt_device(device);
144145

146+
compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(device.id);
145147
// Update active stream based on new device
146-
auto current_stream = c10::cuda::getCurrentCUDAStream(device.id);
147-
if (current_stream == c10::cuda::getDefaultCUDAStream(device.id)) {
148-
compiled_engine->active_stream = c10::cuda::getStreamFromPool(false, device.id);
149-
c10::cuda::setCurrentCUDAStream(compiled_engine->active_stream);
148+
if (compiled_engine->caller_stream == c10::cuda::getDefaultCUDAStream(device.id)) {
149+
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, device.id);
150+
c10::cuda::setCurrentCUDAStream(compiled_engine->engine_stream);
150151
} else {
151-
compiled_engine->active_stream = current_stream;
152+
compiled_engine->engine_stream = compiled_engine->caller_stream;
152153
}
153154

154155
// Target device is new device
@@ -274,16 +275,23 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
274275

275276
if (!CUDAGRAPHS_MODE) {
276277
// If not in cudagraphs mode, proceed with enqueueV3 as normal
277-
compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream);
278+
at::cuda::CUDAEvent caller_exec_complete;
279+
caller_exec_complete.record(compiled_engine->caller_stream);
280+
caller_exec_complete.block(compiled_engine->engine_stream);
281+
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
282+
at::cuda::CUDAEvent trt_exec_complete;
283+
trt_exec_complete.record(compiled_engine->engine_stream);
284+
trt_exec_complete.block(compiled_engine->caller_stream);
278285
} else if (need_cudagraphs_record) {
279286
// If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
280287

281288
// Cudagraphs cannot record on the current stream, so use an alternate
282289
c10::cuda::CUDAStream recording_stream = c10::cuda::getStreamFromPool(false, inputs[0].device().index());
283290
c10::cuda::CUDAStreamGuard guard(recording_stream);
284291

285-
compiled_engine->exec_ctx->enqueueV3(recording_stream);
286-
recording_stream.synchronize();
292+
at::cuda::CUDAEvent caller_exec_complete;
293+
caller_exec_complete.record(compiled_engine->caller_stream);
294+
caller_exec_complete.block(recording_stream);
287295

288296
compiled_engine->cudagraph.capture_begin();
289297
compiled_engine->exec_ctx->enqueueV3(recording_stream);
@@ -294,7 +302,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
294302

295303
} else {
296304
// If the cudagraph has already been recorded, copy the input buffers and replay it
297-
for (auto i = 0; i < inputs.size(); i++) {
305+
for (size_t i = 0; i < inputs.size(); i++) {
298306
compiled_engine->input_buffers[i].copy_(inputs[i], true);
299307
}
300308
compiled_engine->cudagraph.replay();
@@ -305,7 +313,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
305313
// In cudagraphs mode, the output buffers can be reused, so they must
306314
// be cloned before providing them to the user to avoid data corruption
307315
if (CUDAGRAPHS_MODE) {
308-
for (auto i = 0; i < compiled_engine->output_buffers.size(); i++) {
316+
for (size_t i = 0; i < compiled_engine->output_buffers.size(); i++) {
309317
model_outputs[i] = compiled_engine->output_buffers[i].clone();
310318
}
311319
} else {

tests/cpp/test_compiled_modules.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
55
std::vector<torch::jit::IValue> trt_inputs_ivalues;
66
std::vector<torch_tensorrt::Input> shapes;
77
for (uint64_t i = 0; i < input_shapes.size(); i++) {
8-
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
8+
auto in = at::randn(input_shapes[i], {at::kCUDA}).to(input_types[i]);
9+
if (input_types[i] == at::kInt || input_types[i] == at::kLong) {
10+
auto in = at::randint(0, 2, input_shapes[i], {at::kCUDA}).to(input_types[i]);
11+
}
12+
913
jit_inputs_ivalues.push_back(in.clone());
1014
trt_inputs_ivalues.push_back(in.clone());
1115
auto in_spec = torch_tensorrt::Input(input_shapes[i]);

tests/py/ts/models/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_efficientnet_b0(self):
9393
)
9494

9595
def test_bert_base_uncased(self):
96-
self.model = cm.BertModule().cuda()
96+
self.model = cm.BertModule()
9797
self.input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
9898

9999
compile_spec = {
@@ -116,7 +116,7 @@ def test_bert_base_uncased(self):
116116
"enabled_precisions": {torch.float},
117117
"truncate_long_and_double": True,
118118
}
119-
with torchtrt.logging.errors():
119+
with torchtrt.logging.debug():
120120
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
121121

122122
model_outputs = self.model(self.input, self.input)

0 commit comments

Comments (0)