Skip to content

Commit 7996a10

Browse files
committed
feat(//tests): Adding BERT to the test suite
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 72c7b76 commit 7996a10

8 files changed

+69
-44
lines changed

tests/cpp/cpp_api_test.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
#include "torch/script.h"
77
#include "torch_tensorrt/torch_tensorrt.h"
88

9-
using PathAndInSize = std::tuple<std::string, std::vector<std::vector<int64_t>>, float>;
9+
using PathAndInput = std::tuple<std::string, std::vector<std::vector<int64_t>>, std::vector<c10::ScalarType>, float>;
1010

11-
class CppAPITests : public testing::TestWithParam<PathAndInSize> {
11+
class CppAPITests : public testing::TestWithParam<PathAndInput> {
1212
public:
1313
void SetUp() override {
14-
PathAndInSize params = GetParam();
14+
PathAndInput params = GetParam();
1515
std::string path = std::get<0>(params);
1616
try {
1717
// Deserialize the ScriptModule from a file using torch::jit::load().
@@ -21,7 +21,8 @@ class CppAPITests : public testing::TestWithParam<PathAndInSize> {
2121
ASSERT_TRUE(false);
2222
}
2323
input_shapes = std::get<1>(params);
24-
threshold = std::get<2>(params);
24+
input_types = std::get<2>(params);
25+
threshold = std::get<3>(params);
2526
}
2627

2728
void TearDown() {
@@ -32,5 +33,6 @@ class CppAPITests : public testing::TestWithParam<PathAndInSize> {
3233
protected:
3334
torch::jit::script::Module mod;
3435
std::vector<std::vector<int64_t>> input_shapes;
36+
std::vector<c10::ScalarType> input_types;
3537
float threshold;
3638
};

tests/cpp/test_compiled_modules.cpp

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,42 @@
33
TEST_P(CppAPITests, CompiledModuleIsClose) {
44
std::vector<torch::jit::IValue> jit_inputs_ivalues;
55
std::vector<torch::jit::IValue> trt_inputs_ivalues;
6-
for (auto in_shape : input_shapes) {
7-
auto in = at::randint(5, in_shape, {at::kCUDA});
6+
std::vector<torch_tensorrt::Input> shapes;
7+
for (uint64_t i = 0; i < input_shapes.size(); i++) {
8+
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
89
jit_inputs_ivalues.push_back(in.clone());
910
trt_inputs_ivalues.push_back(in.clone());
11+
auto in_spec = torch_tensorrt::Input(input_shapes[i]);
12+
in_spec.dtype = input_types[i];
13+
shapes.push_back(in_spec);
14+
std::cout << in_spec << std::endl;
1015
}
1116

1217
torch::jit::IValue jit_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(mod, jit_inputs_ivalues);
1318
std::vector<at::Tensor> jit_results;
14-
jit_results.push_back(jit_results_ivalues.toTensor());
19+
if (jit_results_ivalues.isTuple()) {
20+
auto tuple = jit_results_ivalues.toTuple();
21+
for (auto t : tuple->elements()) {
22+
jit_results.push_back(t.toTensor());
23+
}
24+
} else {
25+
jit_results.push_back(jit_results_ivalues.toTensor());
26+
}
27+
28+
auto spec = torch_tensorrt::ts::CompileSpec(shapes);
29+
spec.truncate_long_and_double = true;
1530

16-
auto trt_mod = torch_tensorrt::ts::compile(mod, input_shapes);
31+
auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
1732
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
1833
std::vector<at::Tensor> trt_results;
19-
trt_results.push_back(trt_results_ivalues.toTensor());
34+
if (trt_results_ivalues.isTuple()) {
35+
auto tuple = trt_results_ivalues.toTuple();
36+
for (auto t : tuple->elements()) {
37+
trt_results.push_back(t.toTensor());
38+
}
39+
} else {
40+
trt_results.push_back(trt_results_ivalues.toTensor());
41+
}
2042

2143
for (size_t i = 0; i < trt_results.size(); i++) {
2244
ASSERT_TRUE(
@@ -30,13 +52,14 @@ INSTANTIATE_TEST_SUITE_P(
3052
CompiledModuleForwardIsCloseSuite,
3153
CppAPITests,
3254
testing::Values(
33-
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
34-
PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
35-
PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
36-
PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
37-
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
38-
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
39-
PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3}),
40-
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-2})));
55+
PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
56+
PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
57+
PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
58+
PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
59+
PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
60+
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
61+
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-3}),
62+
PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}, 8e-2}),
63+
PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2})));
4164

4265
#endif

tests/cpp/test_default_input_types.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP16WeightsFP32In) {
7878
}
7979

8080
auto in = torch_tensorrt::Input(input_shapes[0]);
81-
in.dtype = torch::kF32;
81+
in.dtype = torch::kFloat;
8282
auto spec = torch_tensorrt::ts::CompileSpec({in});
8383
spec.enabled_precisions.insert(torch_tensorrt::DataType::kHalf);
8484

@@ -116,4 +116,4 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) {
116116
INSTANTIATE_TEST_SUITE_P(
117117
CompiledModuleForwardIsCloseSuite,
118118
CppAPITests,
119-
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
119+
testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5})));

tests/cpp/test_example_tensors.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
TEST_P(CppAPITests, InputsFromTensors) {
44
std::vector<torch::jit::IValue> jit_inputs_ivalues;
55
std::vector<torch::jit::IValue> trt_inputs_ivalues;
6-
for (auto in_shape : input_shapes) {
7-
auto in = at::randn(in_shape, {at::kCUDA});
6+
for (uint64_t i = 0; i < input_shapes.size(); i++) {
7+
auto in = at::randn(input_shapes[i], {at::kCUDA}).to(input_types[i]);
88
jit_inputs_ivalues.push_back(in.clone());
99
trt_inputs_ivalues.push_back(in.clone());
1010
}
@@ -20,4 +20,4 @@ TEST_P(CppAPITests, InputsFromTensors) {
2020
INSTANTIATE_TEST_SUITE_P(
2121
CompiledModuleForwardIsCloseSuite,
2222
CppAPITests,
23-
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
23+
testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5})));

tests/cpp/test_modules_as_engines.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
TEST_P(CppAPITests, ModuleAsEngineIsClose) {
55
std::vector<at::Tensor> inputs;
66
std::vector<torch::jit::IValue> inputs_ivalues;
7-
for (auto in_shape : input_shapes) {
8-
inputs.push_back(at::randint(5, in_shape, {at::kCUDA}));
7+
for (uint64_t i = 0; i < input_shapes.size(); i++) {
8+
inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]));
99
inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
1010
}
1111

@@ -21,8 +21,8 @@ TEST_P(CppAPITests, ModuleAsEngineIsClose) {
2121
TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) {
2222
std::vector<at::Tensor> inputs;
2323
std::vector<torch::jit::IValue> inputs_ivalues;
24-
for (auto in_shape : input_shapes) {
25-
inputs.push_back(at::randint(5, in_shape, {at::kCUDA}));
24+
for (uint64_t i = 0; i < input_shapes.size(); i++) {
25+
inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]));
2626
inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
2727
}
2828

@@ -57,13 +57,13 @@ INSTANTIATE_TEST_SUITE_P(
5757
ModuleAsEngineForwardIsCloseSuite,
5858
CppAPITests,
5959
testing::Values(
60-
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
61-
PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
62-
PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
63-
PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
64-
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
65-
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
66-
PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
67-
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-2})));
60+
PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
61+
PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
62+
PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
63+
PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
64+
PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
65+
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
66+
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
67+
PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2})));
6868

6969
#endif

tests/cpp/test_multi_gpu_serde.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
TEST_P(CppAPITests, CompiledModuleIsClose) {
55
std::vector<torch::jit::IValue> jit_inputs_ivalues;
66
std::vector<torch::jit::IValue> trt_inputs_ivalues;
7-
for (auto in_shape : input_shapes) {
8-
auto in = at::randint(5, in_shape, {at::kCUDA});
7+
for (uint64_t i = 0; i < input_shapes.size(); i++) {
8+
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
99
jit_inputs_ivalues.push_back(in.clone());
1010
trt_inputs_ivalues.push_back(in.clone());
1111
}
@@ -31,4 +31,4 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
3131
INSTANTIATE_TEST_SUITE_P(
3232
CompiledModuleForwardIsCloseSuite,
3333
CppAPITests,
34-
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
34+
testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5})));

tests/cpp/test_serialization.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ std::vector<torch_tensorrt::Input> toInputRangesDynamic(std::vector<std::vector<
2121
TEST_P(CppAPITests, SerializedModuleIsStillCorrect) {
2222
std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
2323
std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
24-
for (auto in_shape : input_shapes) {
25-
auto in = at::randint(5, in_shape, {at::kCUDA});
24+
for (uint64_t i = 0; i < input_shapes.size(); i++) {
25+
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
2626
post_serialized_inputs_ivalues.push_back(in.clone());
2727
pre_serialized_inputs_ivalues.push_back(in.clone());
2828
}
@@ -50,8 +50,8 @@ TEST_P(CppAPITests, SerializedModuleIsStillCorrect) {
5050
TEST_P(CppAPITests, SerializedDynamicModuleIsStillCorrect) {
5151
std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
5252
std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
53-
for (auto in_shape : input_shapes) {
54-
auto in = at::randint(5, in_shape, {at::kCUDA});
53+
for (uint64_t i = 0; i < input_shapes.size(); i++) {
54+
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
5555
post_serialized_inputs_ivalues.push_back(in.clone());
5656
pre_serialized_inputs_ivalues.push_back(in.clone());
5757
}
@@ -81,5 +81,5 @@ INSTANTIATE_TEST_SUITE_P(
8181
CompiledModuleForwardIsCloseSuite,
8282
CppAPITests,
8383
testing::Values(
84-
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
85-
PathAndInSize({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}, 2e-5})));
84+
PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
85+
PathAndInput({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}, {at::kFloat}, 2e-5})));

tests/modules/hub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,4 +217,4 @@ def forward(self, x):
217217
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
218218

219219
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
220-
torch.jit.save(traced_model, "bert_base_uncased_traced.jit..pt")
220+
torch.jit.save(traced_model, "bert_base_uncased_traced.jit.pt")

0 commit comments

Comments
 (0)