Skip to content

Commit 2623220

Browse files
committed
fix(//tests/cpp): Fix the BERT C++ test
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 52f10cf commit 2623220

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

tests/cpp/test_compiled_modules.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
55
std::vector<torch::jit::IValue> trt_inputs_ivalues;
66
std::vector<torch_tensorrt::Input> shapes;
77
for (uint64_t i = 0; i < input_shapes.size(); i++) {
8-
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
8+
auto in = at::randn(input_shapes[i], {at::kCUDA}).to(input_types[i]);
9+
if (input_types[i] == at::kInt || input_types[i] == at::kLong) {
10+
auto in = at::randint(0, 2, input_shapes[i], {at::kCUDA}).to(input_types[i]);
11+
}
12+
913
jit_inputs_ivalues.push_back(in.clone());
1014
trt_inputs_ivalues.push_back(in.clone());
1115
auto in_spec = torch_tensorrt::Input(input_shapes[i]);

tests/py/ts/models/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_efficientnet_b0(self):
9393
)
9494

9595
def test_bert_base_uncased(self):
96-
self.model = cm.BertModule().cuda()
96+
self.model = cm.BertModule()
9797
self.input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
9898

9999
compile_spec = {
@@ -116,7 +116,7 @@ def test_bert_base_uncased(self):
116116
"enabled_precisions": {torch.float},
117117
"truncate_long_and_double": True,
118118
}
119-
with torchtrt.logging.errors():
119+
with torchtrt.logging.debug():
120120
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
121121

122122
model_outputs = self.model(self.input, self.input)

0 commit comments

Comments
 (0)