Skip to content

Commit 0daa237

Browse files
committed
Test(tests/core/converters) added 5 tests for test_interpolate
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent b6942a2 commit 0daa237

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed

tests/core/converters/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ converter_test(
5151
name = "test_unary"
5252
)
5353

54+
converter_test(
55+
name = "test_interpolate"
56+
)
57+
5458
test_suite(
5559
name = "test_converters",
5660
tests = [
@@ -65,6 +69,7 @@ test_suite(
6569
":test_shuffle",
6670
":test_softmax",
6771
":test_unary",
72+
":test_interpolate",
6873
]
6974
)
7075

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/ir/irparser.h"
4+
#include "tests/util/util.h"
5+
#include "core/compiler.h"
6+
7+
TEST(Converters, ATenUpsampleNearest1dConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor):
10+
%2 : int = prim::Constant[value=10]()
11+
%3 : int[] = prim::ListConstruct(%2)
12+
%4 : None = prim::Constant()
13+
%5 : Tensor = aten::upsample_nearest1d(%0, %3, %4)
14+
return (%5))IR";
15+
16+
auto g = std::make_shared<torch::jit::Graph>();
17+
18+
torch::jit::parseIR(graph, &*g);
19+
20+
// Input Tensor needs to be 3D for TensorRT upsample_nearest1d
21+
auto in = at::randint(1, 10, {10, 2, 2}, {at::kCUDA});
22+
23+
auto jit_in = at::clone(in);
24+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
25+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
26+
27+
auto trt_in = at::clone(in);
28+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
29+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
30+
31+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
32+
33+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
34+
}
35+
36+
TEST(Converters, ATenUpsampleNearest2dConvertsCorrectly1dOutputSize) {
37+
const auto graph = R"IR(
38+
graph(%0 : Tensor):
39+
%2 : int = prim::Constant[value=10]()
40+
%3 : int[] = prim::ListConstruct(%2, %2)
41+
%4 : None = prim::Constant()
42+
%5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4)
43+
return (%5))IR";
44+
45+
auto g = std::make_shared<torch::jit::Graph>();
46+
47+
torch::jit::parseIR(graph, &*g);
48+
49+
// Input Tensor needs to be 4D for TensorRT upsample_nearest2d
50+
auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA});
51+
52+
auto jit_in = at::clone(in);
53+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
54+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
55+
56+
auto trt_in = at::clone(in);
57+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
58+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
59+
60+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
61+
62+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
63+
}
64+
65+
TEST(Converters, ATenUpsampleNearest2dConvertsCorrectly2dOutputSize) {
66+
const auto graph = R"IR(
67+
graph(%0 : Tensor):
68+
%2 : int = prim::Constant[value=10]()
69+
%3 : int[] = prim::ListConstruct(%2, %2)
70+
%4 : None = prim::Constant()
71+
%5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4)
72+
return (%5))IR";
73+
74+
auto g = std::make_shared<torch::jit::Graph>();
75+
76+
torch::jit::parseIR(graph, &*g);
77+
78+
// Input Tensor needs to be 4D for TensorRT upsample_nearest2d
79+
auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA});
80+
81+
auto jit_in = at::clone(in);
82+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
83+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
84+
85+
auto trt_in = at::clone(in);
86+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
87+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
88+
89+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
90+
91+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
92+
}
93+
94+
TEST(Converters, ATenUpsampleNearest3dConvertsCorrectly1dOutputSize) {
95+
const auto graph = R"IR(
96+
graph(%0 : Tensor):
97+
%2 : int = prim::Constant[value=10]()
98+
%3 : int[] = prim::ListConstruct(%2, %2, %2)
99+
%4 : None = prim::Constant()
100+
%5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4)
101+
return (%5))IR";
102+
103+
auto g = std::make_shared<torch::jit::Graph>();
104+
105+
torch::jit::parseIR(graph, &*g);
106+
107+
// Input Tensor needs to be 5D for TensorRT upsample_nearest3d
108+
auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA});
109+
110+
auto jit_in = at::clone(in);
111+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
112+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
113+
114+
auto trt_in = at::clone(in);
115+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
116+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
117+
118+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
119+
120+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
121+
}
122+
123+
TEST(Converters, ATenUpsampleNearest3dConvertsCorrectly3dOutputSize) {
124+
const auto graph = R"IR(
125+
graph(%0 : Tensor):
126+
%2 : int = prim::Constant[value=10]()
127+
%3 : int[] = prim::ListConstruct(%2, %2, %2)
128+
%4 : None = prim::Constant()
129+
%5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4)
130+
return (%5))IR";
131+
132+
auto g = std::make_shared<torch::jit::Graph>();
133+
134+
torch::jit::parseIR(graph, &*g);
135+
136+
// Input Tensor needs to be 5D for TensorRT upsample_nearest3d
137+
auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA});
138+
139+
auto jit_in = at::clone(in);
140+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
141+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
142+
143+
auto trt_in = at::clone(in);
144+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
145+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
146+
147+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
148+
149+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
150+
}

0 commit comments

Comments
 (0)