Skip to content

Commit e776efb

Browse files
authored
refactor: Split elementwise tests (#1507)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent b1829b7 commit e776efb

File tree

10 files changed

+709
-635
lines changed

10 files changed

+709
-635
lines changed

tests/core/conversion/converters/BUILD

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ converter_test(
1111
name = "test_activation",
1212
)
1313

14+
converter_test(
15+
name = "test_add_sub_mul",
16+
)
17+
18+
converter_test(
19+
name = "test_atan2",
20+
)
21+
1422
converter_test(
1523
name = "test_batch_norm",
1624
)
@@ -31,6 +39,10 @@ converter_test(
3139
name = "test_clone",
3240
)
3341

42+
converter_test(
43+
name = "test_clamp",
44+
)
45+
3446
converter_test(
3547
name = "test_concat",
3648
)
@@ -47,10 +59,18 @@ converter_test(
4759
name = "test_copy",
4860
)
4961

62+
converter_test(
63+
name = "test_comparators",
64+
)
65+
5066
converter_test(
5167
name = "test_cumsum",
5268
)
5369

70+
converter_test(
71+
name = "test_div",
72+
)
73+
5474
converter_test(
5575
name = "test_einsum",
5676
)
@@ -147,17 +167,21 @@ test_suite(
147167
name = "converter_tests",
148168
tests = [
149169
":test_activation",
170+
":test_add_sub_mul",
171+
":test_atan2",
150172
":test_batch_norm",
151173
":test_bitwise",
152174
":test_cast",
175+
":test_clamp",
153176
":test_clone",
177+
":test_comparators",
154178
":test_concat",
155179
":test_constant_pad",
156180
":test_conv_deconv",
157181
":test_copy",
158182
":test_cumsum",
183+
":test_div",
159184
":test_einsum",
160-
":test_element_wise",
161185
":test_expand",
162186
":test_instance_norm",
163187
":test_interpolate",
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/torch.h"
8+
9+
using torch_tensorrt::tests::util::pointwise_test_helper;
10+
11+
TEST(Converters, ATenAddConvertsCorrectly) {
12+
const auto graph = R"IR(
13+
graph(%0 : Tensor, %1 : Tensor):
14+
%2 : int = prim::Constant[value=1]()
15+
%3 : Tensor = aten::add(%0, %1, %2)
16+
return (%3))IR";
17+
pointwise_test_helper(graph, false);
18+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
19+
pointwise_test_helper(graph, false, false, {4}, {3, 4});
20+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
21+
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
22+
}
23+
24+
TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
25+
const auto graph = R"IR(
26+
graph(%0 : Tensor, %1 : Tensor):
27+
%2 : float = prim::Constant[value=3.2]()
28+
%3 : Tensor = aten::add(%0, %1, %2)
29+
return (%3))IR";
30+
pointwise_test_helper(graph, false);
31+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
32+
pointwise_test_helper(graph, false, false, {4}, {3, 4});
33+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
34+
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
35+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
36+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
37+
}
38+
39+
TEST(Converters, ATenAddInplaceWithAlphaConvertsCorrectly) {
40+
const auto graph = R"IR(
41+
graph(%0 : Tensor, %1 : Tensor):
42+
%2 : float = prim::Constant[value=7.6]()
43+
%3 : Tensor = aten::add_(%0, %1, %2)
44+
return (%3))IR";
45+
pointwise_test_helper(graph, false);
46+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
47+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
48+
pointwise_test_helper(graph, false, false, {3, 4, 3}, {4, 3}, false, at::kFloat, at::kInt);
49+
}
50+
51+
TEST(Converters, ATenAddImplicitWithIntAlphaConvertsCorrectly) {
52+
const auto graph = R"IR(
53+
graph(%0 : Tensor, %1 : Tensor):
54+
%2 : int = prim::Constant[value=42]()
55+
%3 : Tensor = aten::add_(%0, %1, %2)
56+
return (%3))IR";
57+
pointwise_test_helper(graph, false, false, {2, 2}, {2, 2}, false, at::kInt, at::kInt);
58+
pointwise_test_helper(graph, false, false, {3, 4, 3}, {4, 3}, false, at::kInt, at::kInt);
59+
}
60+
61+
TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
62+
const auto graph = R"IR(
63+
graph(%0 : Tensor):
64+
%2 : int = prim::Constant[value=1]()
65+
%scalar : float = prim::Constant[value=2.4]()
66+
%3 : Tensor = aten::add(%0, %scalar, %2)
67+
return (%3))IR";
68+
pointwise_test_helper(graph, true);
69+
}
70+
71+
TEST(Converters, ATenSubConvertsCorrectly) {
72+
const auto graph = R"IR(
73+
graph(%0 : Tensor, %1 : Tensor):
74+
%2 : int = prim::Constant[value=2.3]()
75+
%3 : Tensor = aten::sub(%0, %1, %2)
76+
return (%3))IR";
77+
pointwise_test_helper(graph, false);
78+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
79+
pointwise_test_helper(graph, false, false, {4}, {3, 4});
80+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
81+
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
82+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
83+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
84+
}
85+
86+
TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
87+
const auto graph = R"IR(
88+
graph(%0 : Tensor, %1 : Tensor):
89+
%2 : int = prim::Constant[value=2]()
90+
%3 : Tensor = aten::rsub(%0, %1, %2)
91+
return (%3))IR";
92+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
93+
pointwise_test_helper(graph, false, false, {4}, {3, 4});
94+
pointwise_test_helper(graph, false, true, {4, 3, 3, 3}, {4, 3, 3, 3});
95+
pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kFloat);
96+
pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kInt);
97+
}
98+
99+
TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
100+
const auto graph = R"IR(
101+
graph(%0 : Tensor):
102+
%2 : int = prim::Constant[value=2]()
103+
%scalar : float = prim::Constant[value=2.4]()
104+
%3 : Tensor = aten::rsub(%0, %scalar, %2)
105+
return (%3))IR";
106+
pointwise_test_helper(graph, true, false, {4, 3, 3, 3});
107+
}
108+
109+
TEST(Converters, ATenRsubWithIntScalarConvertsCorrectly) {
110+
const auto graph = R"IR(
111+
graph(%0 : Tensor):
112+
%2 : int = prim::Constant[value=2]()
113+
%scalar : int = prim::Constant[value=8]()
114+
%3 : Tensor = aten::rsub(%0, %scalar, %2)
115+
return (%3))IR";
116+
pointwise_test_helper(graph, true, false, {4, 3, 3, 3}, {}, false, at::kInt);
117+
}
118+
119+
TEST(Converters, ATenMulConvertsCorrectly) {
120+
const auto graph = R"IR(
121+
graph(%0 : Tensor, %1 : Tensor):
122+
%2 : Tensor = aten::mul(%0, %1)
123+
return (%2))IR";
124+
pointwise_test_helper(graph, false);
125+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
126+
pointwise_test_helper(graph, false, false, {4}, {3, 4});
127+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
128+
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
129+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
130+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
131+
}
132+
133+
TEST(Converters, ATenSquareConvertsCorrectly) {
134+
const auto graph = R"IR(
135+
graph(%0 : Tensor):
136+
%1 : Tensor = aten::square(%0)
137+
return (%1))IR";
138+
pointwise_test_helper(graph, true);
139+
}
140+
141+
TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
142+
const auto graph = R"IR(
143+
graph(%0 : Tensor):
144+
%scalar : float = prim::Constant[value=2.4]()
145+
%1 : Tensor = aten::mul(%0, %scalar)
146+
return (%1))IR";
147+
pointwise_test_helper(graph, true);
148+
}
149+
150+
TEST(Converters, ATenMulWithIntScalarConvertsCorrectly) {
151+
const auto graph = R"IR(
152+
graph(%0 : Tensor):
153+
%scalar : int = prim::Constant[value=2]()
154+
%1 : Tensor = aten::mul(%0, %scalar)
155+
return (%1))IR";
156+
pointwise_test_helper(graph, true, false, {5}, {5}, false, at::kInt);
157+
}
158+
159+
TEST(Converters, ATenPowTensorConvertsCorrectly) {
160+
const auto graph = R"IR(
161+
graph(%x.1 : Tensor, %x2.1 : Tensor):
162+
%3 : Tensor = aten::pow(%x.1, %x2.1)
163+
return (%3))IR";
164+
pointwise_test_helper(graph, false);
165+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
166+
pointwise_test_helper(graph, false, false, {4}, {3, 4});
167+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
168+
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
169+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
170+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
171+
}
172+
173+
TEST(Converters, ATenPowScalarConvertsCorrectly) {
174+
const auto graph = R"IR(
175+
graph(%x.1 : Tensor):
176+
%2 : int = prim::Constant[value=2]()
177+
%3 : Tensor = aten::pow(%x.1, %2)
178+
return (%3))IR";
179+
pointwise_test_helper(graph, true);
180+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/torch.h"
8+
9+
TEST(Converters, ATenAtan2ConvertsCorrectly) {
10+
const auto graph = R"IR(
11+
graph(%x.0 : Tensor, %x.1 : Tensor):
12+
%2 : Tensor = aten::atan2(%x.0, %x.1)
13+
return (%2))IR";
14+
15+
auto g = std::make_shared<torch::jit::Graph>();
16+
torch::jit::parseIR(graph, g.get());
17+
18+
// Resize range to [-1, 1] to span multiple quadrants
19+
auto in_0 = -2 * at::rand({2, 3, 5, 5}, {at::kCUDA}) + 1;
20+
auto in_1 = -2 * at::rand({2, 3, 5, 5}, {at::kCUDA}) + 1;
21+
22+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
23+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
24+
25+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
26+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
27+
28+
ASSERT_TRUE(
29+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
30+
}
31+
32+
TEST(Converters, ATenAtan2ManagesPosInfCorrectly) {
33+
const auto graph = R"IR(
34+
graph(%x.0 : Tensor, %x.1 : Tensor):
35+
%2 : Tensor = aten::atan2(%x.0, %x.1)
36+
return (%2))IR";
37+
38+
auto g = std::make_shared<torch::jit::Graph>();
39+
torch::jit::parseIR(graph, g.get());
40+
41+
// Expecting PI/2
42+
auto in_0 = at::ones({4, 1, 7, 8}, {at::kCUDA});
43+
auto in_1 = at::zeros({4, 1, 7, 8}, {at::kCUDA});
44+
45+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
46+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
47+
48+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
49+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
50+
51+
ASSERT_TRUE(
52+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
53+
}
54+
55+
TEST(Converters, ATenAtan2ManagesNegInfCorrectly) {
56+
const auto graph = R"IR(
57+
graph(%x.0 : Tensor, %x.1 : Tensor):
58+
%2 : Tensor = aten::atan2(%x.0, %x.1)
59+
return (%2))IR";
60+
61+
auto g = std::make_shared<torch::jit::Graph>();
62+
torch::jit::parseIR(graph, g.get());
63+
64+
// Expecting -PI/2
65+
auto in_0 = -1 * at::ones({4, 1, 7, 8}, {at::kCUDA});
66+
auto in_1 = at::zeros({4, 1, 7, 8}, {at::kCUDA});
67+
68+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
69+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
70+
71+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
72+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
73+
74+
ASSERT_TRUE(
75+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
76+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/torch.h"
8+
9+
using torch_tensorrt::tests::util::pointwise_test_helper;
10+
11+
TEST(Converters, ATenClampMinConvertsCorrectly) {
12+
const auto graph = R"IR(
13+
graph(%x.1 : Tensor):
14+
%2 : float = prim::Constant[value=1.5]()
15+
%3 : None = prim::Constant()
16+
%4 : Tensor = aten::clamp(%x.1, %2, %3)
17+
return (%4))IR";
18+
pointwise_test_helper(graph, true);
19+
}
20+
21+
TEST(Converters, ATenClampMaxConvertsCorrectly) {
22+
const auto graph = R"IR(
23+
graph(%x.1 : Tensor):
24+
%2 : float = prim::Constant[value=3.5]()
25+
%3 : None = prim::Constant()
26+
%4 : Tensor = aten::clamp(%x.1, %3, %2)
27+
return (%4))IR";
28+
pointwise_test_helper(graph, true);
29+
}
30+
31+
TEST(Converters, ATenClampMinMaxConvertsCorrectly) {
32+
const auto graph = R"IR(
33+
graph(%x.1 : Tensor):
34+
%2 : float = prim::Constant[value=3.5]()
35+
%3 : float = prim::Constant[value=1.5]()
36+
%4 : Tensor = aten::clamp(%x.1, %3, %2)
37+
return (%4))IR";
38+
pointwise_test_helper(graph, true);
39+
}
40+
41+
TEST(Converters, ATenClampMinimumConvertsCorrectly) {
42+
const auto graph = R"IR(
43+
graph(%x.1 : Tensor):
44+
%2 : float = prim::Constant[value=2.5]()
45+
%4 : Tensor = aten::clamp_min(%x.1, %2)
46+
return (%4))IR";
47+
pointwise_test_helper(graph, true);
48+
}
49+
50+
TEST(Converters, ATenClampMaximumConvertsCorrectly) {
51+
const auto graph = R"IR(
52+
graph(%x.1 : Tensor):
53+
%2 : float = prim::Constant[value=2.5]()
54+
%4 : Tensor = aten::clamp_max(%x.1, %2)
55+
return (%4))IR";
56+
pointwise_test_helper(graph, true);
57+
}

0 commit comments

Comments
 (0)