1
+ #include < string>
2
+ #include " gtest/gtest.h"
3
+ #include " torch/csrc/jit/ir/irparser.h"
4
+ #include " tests/util/util.h"
5
+ #include " core/compiler.h"
6
+
7
+ TEST (Converters, ATenUpsampleNearest1dConvertsCorrectly) {
8
+ const auto graph = R"IR(
9
+ graph(%0 : Tensor):
10
+ %2 : int = prim::Constant[value=10]()
11
+ %3 : int[] = prim::ListConstruct(%2)
12
+ %4 : None = prim::Constant()
13
+ %5 : Tensor = aten::upsample_nearest1d(%0, %3, %4)
14
+ return (%5))IR" ;
15
+
16
+ auto g = std::make_shared<torch::jit::Graph>();
17
+
18
+ torch::jit::parseIR (graph, &*g);
19
+
20
+ // Input Tensor needs to be 3D for TensorRT upsample_nearest1d
21
+ auto in = at::randint (1 , 10 , {10 , 2 , 2 }, {at::kCUDA });
22
+
23
+ auto jit_in = at::clone (in);
24
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
25
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
26
+
27
+ auto trt_in = at::clone (in);
28
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
29
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
30
+
31
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
32
+
33
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
34
+ }
35
+
36
+ TEST (Converters, ATenUpsampleNearest2dConvertsCorrectly1dOutputSize) {
37
+ const auto graph = R"IR(
38
+ graph(%0 : Tensor):
39
+ %2 : int = prim::Constant[value=10]()
40
+ %3 : int[] = prim::ListConstruct(%2, %2)
41
+ %4 : None = prim::Constant()
42
+ %5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4)
43
+ return (%5))IR" ;
44
+
45
+ auto g = std::make_shared<torch::jit::Graph>();
46
+
47
+ torch::jit::parseIR (graph, &*g);
48
+
49
+ // Input Tensor needs to be 4D for TensorRT upsample_nearest2d
50
+ auto in = at::randint (1 , 10 , {10 , 2 , 2 , 2 }, {at::kCUDA });
51
+
52
+ auto jit_in = at::clone (in);
53
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
54
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
55
+
56
+ auto trt_in = at::clone (in);
57
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
58
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
59
+
60
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
61
+
62
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
63
+ }
64
+
65
+ TEST (Converters, ATenUpsampleNearest2dConvertsCorrectly2dOutputSize) {
66
+ const auto graph = R"IR(
67
+ graph(%0 : Tensor):
68
+ %2 : int = prim::Constant[value=10]()
69
+ %3 : int[] = prim::ListConstruct(%2, %2)
70
+ %4 : None = prim::Constant()
71
+ %5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4)
72
+ return (%5))IR" ;
73
+
74
+ auto g = std::make_shared<torch::jit::Graph>();
75
+
76
+ torch::jit::parseIR (graph, &*g);
77
+
78
+ // Input Tensor needs to be 4D for TensorRT upsample_nearest2d
79
+ auto in = at::randint (1 , 10 , {10 , 2 , 2 , 2 }, {at::kCUDA });
80
+
81
+ auto jit_in = at::clone (in);
82
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
83
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
84
+
85
+ auto trt_in = at::clone (in);
86
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
87
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
88
+
89
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
90
+
91
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
92
+ }
93
+
94
+ TEST (Converters, ATenUpsampleNearest3dConvertsCorrectly1dOutputSize) {
95
+ const auto graph = R"IR(
96
+ graph(%0 : Tensor):
97
+ %2 : int = prim::Constant[value=10]()
98
+ %3 : int[] = prim::ListConstruct(%2, %2, %2)
99
+ %4 : None = prim::Constant()
100
+ %5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4)
101
+ return (%5))IR" ;
102
+
103
+ auto g = std::make_shared<torch::jit::Graph>();
104
+
105
+ torch::jit::parseIR (graph, &*g);
106
+
107
+ // Input Tensor needs to be 5D for TensorRT upsample_nearest3d
108
+ auto in = at::randint (1 , 10 , {10 , 2 , 2 , 2 , 2 }, {at::kCUDA });
109
+
110
+ auto jit_in = at::clone (in);
111
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
112
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
113
+
114
+ auto trt_in = at::clone (in);
115
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
116
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
117
+
118
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
119
+
120
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
121
+ }
122
+
123
+ TEST (Converters, ATenUpsampleNearest3dConvertsCorrectly3dOutputSize) {
124
+ const auto graph = R"IR(
125
+ graph(%0 : Tensor):
126
+ %2 : int = prim::Constant[value=10]()
127
+ %3 : int[] = prim::ListConstruct(%2, %2, %2)
128
+ %4 : None = prim::Constant()
129
+ %5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4)
130
+ return (%5))IR" ;
131
+
132
+ auto g = std::make_shared<torch::jit::Graph>();
133
+
134
+ torch::jit::parseIR (graph, &*g);
135
+
136
+ // Input Tensor needs to be 5D for TensorRT upsample_nearest3d
137
+ auto in = at::randint (1 , 10 , {10 , 2 , 2 , 2 , 2 }, {at::kCUDA });
138
+
139
+ auto jit_in = at::clone (in);
140
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
141
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
142
+
143
+ auto trt_in = at::clone (in);
144
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
145
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
146
+
147
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
148
+
149
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
150
+ }
0 commit comments