Skip to content

Commit 5069368

Browse files
author
Anurag Dixit
committed
Added test case for DLA device serialization
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 29d1a9b commit 5069368

File tree

3 files changed

+117
-5
lines changed

3 files changed

+117
-5
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ namespace trtorch {
1111
namespace core {
1212
namespace runtime {
1313

14-
const std::string empty_string = std::string();
15-
1614
std::string slugify(std::string s) {
1715
std::replace(s.begin(), s.end(), '.', '_');
1816
return s;
@@ -24,7 +22,7 @@ TRTEngine::TRTEngine(std::string serialized_engine)
2422
util::logging::get_logger().get_reportable_severity(),
2523
util::logging::get_logger().get_is_colored_output_on()) {
2624
std::string _name = "deserialized_trt";
27-
new (this) TRTEngine(_name, serialized_engine, empty_string);
25+
new (this) TRTEngine(_name, serialized_engine, std::string());
2826
}
2927

3028
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
@@ -42,7 +40,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
4240
TRTEngine::TRTEngine(
4341
std::string mod_name,
4442
std::string serialized_engine,
45-
std::string serialized_device_info = empty_string)
43+
std::string serialized_device_info = std::string())
4644
: logger(
4745
std::string("[") + mod_name + std::string("_engine] - "),
4846
util::logging::get_logger().get_reportable_severity(),

tests/modules/BUILD

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ test_suite(
2828
":test_modules_as_engines",
2929
":test_compiled_modules",
3030
":test_multiple_registered_engines",
31-
":test_serialization"
31+
":test_serialization",
32+
":test_dla_serialization"
3233
]
3334
)
3435

@@ -43,6 +44,17 @@ cc_test(
4344
]
4445
)
4546

47+
cc_test(
48+
name = "test_dla_serialization",
49+
srcs = ["test_dla_serialization.cpp"],
50+
deps = [
51+
":module_test",
52+
],
53+
data = [
54+
":jit_models"
55+
]
56+
)
57+
4658
cc_test(
4759
name = "test_multiple_registered_engines",
4860
srcs = ["test_multiple_registered_engines.cpp"],
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#include "module_test.h"
2+
3+
std::vector<trtorch::CompileSpec::InputRange> toInputRangesDynamic(std::vector<std::vector<int64_t>> opts) {
4+
std::vector<trtorch::CompileSpec::InputRange> a;
5+
6+
for (auto opt : opts) {
7+
std::vector<int64_t> min_range(opt);
8+
std::vector<int64_t> max_range(opt);
9+
10+
min_range[3] = ceil(opt[3] / 2.0);
11+
max_range[3] = 2 * opt[3];
12+
min_range[2] = ceil(opt[2] / 2.0);
13+
max_range[2] = 2 * opt[2];
14+
15+
a.push_back(trtorch::CompileSpec::InputRange(min_range, opt, max_range));
16+
}
17+
18+
return std::move(a);
19+
}
20+
21+
TEST_P(ModuleTests, SerializedModuleIsStillCorrect) {
22+
trtorch::set_device(0);
23+
std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
24+
std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
25+
for (auto in_shape : input_shapes) {
26+
auto in = at::randint(5, in_shape, {at::kCUDA}).to(torch::kF16);
27+
post_serialized_inputs_ivalues.push_back(in.clone());
28+
pre_serialized_inputs_ivalues.push_back(in.clone());
29+
}
30+
31+
auto compile_spec = trtorch::CompileSpec(toInputRangesDynamic(input_shapes));
32+
compile_spec.op_precision = torch::kF16;
33+
compile_spec.device.device_type = trtorch::CompileSpec::Device::DeviceType::kDLA;
34+
compile_spec.device.gpu_id = 0;
35+
compile_spec.device.dla_core = 1;
36+
compile_spec.device.allow_gpu_fallback = true;
37+
compile_spec.workspace_size = 1 << 28;
38+
39+
auto pre_serialized_mod = trtorch::CompileGraph(mod, compile_spec);
40+
torch::jit::IValue pre_serialized_results_ivalues =
41+
trtorch::tests::util::RunModuleForward(pre_serialized_mod, pre_serialized_inputs_ivalues);
42+
std::vector<at::Tensor> pre_serialized_results;
43+
pre_serialized_results.push_back(pre_serialized_results_ivalues.toTensor());
44+
45+
pre_serialized_mod.save("test_serialization_mod.ts");
46+
auto post_serialized_mod = torch::jit::load("test_serialization_mod.ts");
47+
48+
torch::jit::IValue post_serialized_results_ivalues =
49+
trtorch::tests::util::RunModuleForward(post_serialized_mod, post_serialized_inputs_ivalues);
50+
std::vector<at::Tensor> post_serialized_results;
51+
post_serialized_results.push_back(post_serialized_results_ivalues.toTensor());
52+
53+
for (size_t i = 0; i < pre_serialized_results.size(); i++) {
54+
ASSERT_TRUE(trtorch::tests::util::almostEqual(
55+
post_serialized_results[i], pre_serialized_results[i].reshape_as(post_serialized_results[i]), 2e-5));
56+
}
57+
}
58+
59+
TEST_P(ModuleTests, SerializedDynamicModuleIsStillCorrect) {
60+
trtorch::set_device(0);
61+
std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
62+
std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
63+
for (auto in_shape : input_shapes) {
64+
auto in = at::randint(5, in_shape, {at::kCUDA}).to(torch::kF16);
65+
post_serialized_inputs_ivalues.push_back(in.clone());
66+
pre_serialized_inputs_ivalues.push_back(in.clone());
67+
}
68+
69+
auto compile_spec = trtorch::CompileSpec(toInputRangesDynamic(input_shapes));
70+
compile_spec.op_precision = torch::kF16;
71+
compile_spec.device.device_type = trtorch::CompileSpec::Device::DeviceType::kDLA;
72+
compile_spec.device.gpu_id = 0;
73+
compile_spec.device.dla_core = 1;
74+
compile_spec.device.allow_gpu_fallback = true;
75+
compile_spec.workspace_size = 1 << 28;
76+
77+
auto pre_serialized_mod = trtorch::CompileGraph(mod, compile_spec);
78+
torch::jit::IValue pre_serialized_results_ivalues =
79+
trtorch::tests::util::RunModuleForward(pre_serialized_mod, pre_serialized_inputs_ivalues);
80+
std::vector<at::Tensor> pre_serialized_results;
81+
pre_serialized_results.push_back(pre_serialized_results_ivalues.toTensor());
82+
83+
pre_serialized_mod.save("test_serialization_mod.ts");
84+
auto post_serialized_mod = torch::jit::load("test_serialization_mod.ts");
85+
86+
torch::jit::IValue post_serialized_results_ivalues =
87+
trtorch::tests::util::RunModuleForward(post_serialized_mod, post_serialized_inputs_ivalues);
88+
std::vector<at::Tensor> post_serialized_results;
89+
post_serialized_results.push_back(post_serialized_results_ivalues.toTensor());
90+
91+
for (size_t i = 0; i < pre_serialized_results.size(); i++) {
92+
ASSERT_TRUE(trtorch::tests::util::almostEqual(
93+
post_serialized_results[i], pre_serialized_results[i].reshape_as(post_serialized_results[i]), 2e-5));
94+
}
95+
}
96+
97+
INSTANTIATE_TEST_SUITE_P(
98+
CompiledModuleForwardIsCloseSuite,
99+
ModuleTests,
100+
testing::Values(
101+
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}}),
102+
PathAndInSize({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}})));

0 commit comments

Comments
 (0)