Skip to content

Commit d207504

Browse files
authored
Merge pull request #518 from NVIDIA/vit
chore: Add efficientnet b0 and VIT to testsuite
2 parents 2a00b63 + ae18b21 commit d207504

File tree

5 files changed

+35
-24
lines changed

5 files changed

+35
-24
lines changed

tests/modules/hub.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,13 @@
6666
"model": models.detection.fasterrcnn_resnet50_fpn(pretrained=True),
6767
"path": "script"
6868
},
69-
"vit": {
69+
"efficientnet_b0": {
7070
"model": timm.create_model('efficientnet_b0', pretrained=True),
7171
"path": "script"
72+
},
73+
"vit": {
74+
"model": timm.create_model('vit_base_patch16_224', pretrained=True),
75+
"path": "script"
7276
}
7377
}
7478

tests/modules/module_test.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,22 @@
66
#include "torch/script.h"
77
#include "trtorch/trtorch.h"
88

9-
using PathAndInSize = std::pair<std::string, std::vector<std::vector<int64_t>>>;
9+
using PathAndInSize = std::tuple<std::string, std::vector<std::vector<int64_t>>, float>;
1010

1111
class ModuleTests : public testing::TestWithParam<PathAndInSize> {
1212
public:
1313
void SetUp() override {
14-
auto params = GetParam();
15-
auto path = params.first;
14+
PathAndInSize params = GetParam();
15+
std::string path = std::get<0>(params);
1616
try {
1717
// Deserialize the ScriptModule from a file using torch::jit::load().
1818
mod = torch::jit::load(path);
1919
} catch (const c10::Error& e) {
2020
std::cerr << "error loading the model\n";
2121
return;
2222
}
23-
input_shapes = params.second;
23+
input_shapes = std::get<1>(params);
24+
threshold = std::get<2>(params);
2425
}
2526

2627
void TearDown() {
@@ -31,4 +32,5 @@ class ModuleTests : public testing::TestWithParam<PathAndInSize> {
3132
protected:
3233
torch::jit::script::Module mod;
3334
std::vector<std::vector<int64_t>> input_shapes;
35+
float threshold;
3436
};

tests/modules/test_compiled_modules.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,20 @@ TEST_P(ModuleTests, CompiledModuleIsClose) {
1919
trt_results.push_back(trt_results_ivalues.toTensor());
2020

2121
for (size_t i = 0; i < trt_results.size(); i++) {
22-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), 2e-5));
22+
ASSERT_TRUE(
23+
trtorch::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), threshold));
2324
}
2425
}
2526

2627
INSTANTIATE_TEST_SUITE_P(
2728
CompiledModuleForwardIsCloseSuite,
2829
ModuleTests,
2930
testing::Values(
30-
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}}),
31-
PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}}),
32-
PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}}),
33-
PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}}),
34-
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}}),
35-
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}})));
31+
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
32+
PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
33+
PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
34+
PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
35+
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
36+
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
37+
PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
38+
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3})));

tests/modules/test_modules_as_engines.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
1313
jit_results.push_back(jit_results_ivalues.toTensor());
1414
auto trt_results = trtorch::tests::util::RunModuleForwardAsEngine(mod, inputs);
1515

16-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
16+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), threshold));
1717
}
1818

1919
TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
@@ -41,16 +41,18 @@ TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
4141
std::vector<at::Tensor> trt_results;
4242
trt_results.push_back(trt_results_ivalues.toTensor());
4343

44-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
44+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), threshold));
4545
}
4646

4747
INSTANTIATE_TEST_SUITE_P(
4848
ModuleAsEngineForwardIsCloseSuite,
4949
ModuleTests,
5050
testing::Values(
51-
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}}),
52-
PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}}),
53-
PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}}),
54-
PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}}),
55-
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}}),
56-
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}})));
51+
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
52+
PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
53+
PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
54+
PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
55+
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
56+
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
57+
PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
58+
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3})));

tests/modules/test_serialization.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ TEST_P(ModuleTests, SerializedModuleIsStillCorrect) {
4343

4444
for (size_t i = 0; i < pre_serialized_results.size(); i++) {
4545
ASSERT_TRUE(trtorch::tests::util::almostEqual(
46-
post_serialized_results[i], pre_serialized_results[i].reshape_as(post_serialized_results[i]), 2e-5));
46+
post_serialized_results[i], pre_serialized_results[i].reshape_as(post_serialized_results[i]), threshold));
4747
}
4848
}
4949

@@ -72,13 +72,13 @@ TEST_P(ModuleTests, SerializedDynamicModuleIsStillCorrect) {
7272

7373
for (size_t i = 0; i < pre_serialized_results.size(); i++) {
7474
ASSERT_TRUE(trtorch::tests::util::almostEqual(
75-
post_serialized_results[i], pre_serialized_results[i].reshape_as(post_serialized_results[i]), 2e-5));
75+
post_serialized_results[i], pre_serialized_results[i].reshape_as(post_serialized_results[i]), threshold));
7676
}
7777
}
7878

7979
INSTANTIATE_TEST_SUITE_P(
8080
CompiledModuleForwardIsCloseSuite,
8181
ModuleTests,
8282
testing::Values(
83-
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}}),
84-
PathAndInSize({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}})));
83+
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
84+
PathAndInSize({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}, 2e-5})));

0 commit comments

Comments
 (0)