@@ -13,7 +13,7 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
13
13
jit_results.push_back (jit_results_ivalues.toTensor ());
14
14
auto trt_results = trtorch::tests::util::RunModuleForwardAsEngine (mod, inputs);
15
15
16
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 8e-3 ));
16
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), threshold ));
17
17
}
18
18
19
19
TEST_P (ModuleTests, ModuleToEngineToModuleIsClose) {
@@ -41,18 +41,18 @@ TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
41
41
std::vector<at::Tensor> trt_results;
42
42
trt_results.push_back (trt_results_ivalues.toTensor ());
43
43
44
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 8e-3 ));
44
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), threshold ));
45
45
}
46
46
47
47
INSTANTIATE_TEST_SUITE_P (
48
48
ModuleAsEngineForwardIsCloseSuite,
49
49
ModuleTests,
50
50
testing::Values (
51
- PathAndInSize ({" tests/modules/resnet18_traced.jit.pt" , {{1 , 3 , 224 , 224 }}}),
52
- PathAndInSize({" tests/modules/resnet50_traced.jit.pt" , {{1 , 3 , 224 , 224 }}}),
53
- PathAndInSize({" tests/modules/mobilenet_v2_traced.jit.pt" , {{1 , 3 , 224 , 224 }}}),
54
- PathAndInSize({" tests/modules/resnet18_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}}),
55
- PathAndInSize({" tests/modules/resnet50_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}}),
56
- PathAndInSize({" tests/modules/mobilenet_v2_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}}),
57
- PathAndInSize({" tests/modules/efficientnet_b0_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}}),
58
- PathAndInSize({" tests/modules/vit_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}})));
51
+ PathAndInSize ({" tests/modules/resnet18_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
52
+ PathAndInSize({" tests/modules/resnet50_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
53
+ PathAndInSize({" tests/modules/mobilenet_v2_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
54
+ PathAndInSize({" tests/modules/resnet18_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
55
+ PathAndInSize({" tests/modules/resnet50_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
56
+ PathAndInSize({" tests/modules/mobilenet_v2_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
57
+ PathAndInSize({" tests/modules/efficientnet_b0_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
58
+ PathAndInSize({" tests/modules/vit_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 8e-3 })));
0 commit comments