@@ -12,7 +12,9 @@ void pointwise_test_helper(
12
12
std::vector<int64_t > shape1 = {5 },
13
13
std::vector<int64_t > shape2 = {5 },
14
14
bool negative_input = false ,
15
- bool int_tensors = false ) {
15
+ bool int_tensors = false ,
16
+ bool float_int_tensors = false ,
17
+ bool int_float_tensors = false ) {
16
18
auto g = std::make_shared<torch::jit::Graph>();
17
19
torch::jit::parseIR (graph_ir, g.get ());
18
20
@@ -27,11 +29,24 @@ void pointwise_test_helper(
27
29
if (!singleInput) {
28
30
torch_inputs.push_back (at::randint (1 , 5 , shape2, {at::kCUDA }));
29
31
}
32
+
33
+ TORCHTRT_CHECK (!((int_tensors && (float_int_tensors || int_float_tensors)) || (float_int_tensors && int_float_tensors)),
34
+ " Invalid test configuration, only one of int_tensors, float_int_tensors, int_float_tensors can be true" );
35
+
30
36
if (int_tensors){
31
37
for (size_t i = 0UL ; i < torch_inputs.size (); ++i){
32
38
torch_inputs[i] = torch_inputs[i].to (at::kInt );
33
39
}
40
+ } else if (float_int_tensors) {
41
+ TORCHTRT_CHECK (!singleInput, " float_int_tensors tests require two inputs" );
42
+ torch_inputs[0 ] = torch_inputs[0 ].to (at::kFloat );
43
+ torch_inputs[1 ] = torch_inputs[1 ].to (at::kInt );
44
+ } else if (int_float_tensors) {
45
+ TORCHTRT_CHECK (!singleInput, " int_float_tensors tests require two inputs" );
46
+ torch_inputs[0 ] = torch_inputs[0 ].to (at::kInt );
47
+ torch_inputs[1 ] = torch_inputs[1 ].to (at::kFloat );
34
48
}
49
+
35
50
auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
36
51
auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, torch_inputs);
37
52
@@ -62,6 +77,8 @@ TEST(Converters, ATenAddConvertsCorrectly) {
62
77
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
63
78
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
64
79
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
80
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
81
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
65
82
}
66
83
67
84
TEST (Converters, ATenAddWithAlphaConvertsCorrectly) {
@@ -75,9 +92,11 @@ TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
75
92
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
76
93
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
77
94
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
95
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
96
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
78
97
}
79
98
80
- TEST (Converters, ATenAddImplicitWithAlphaConvertsCorrectly ) {
99
+ TEST (Converters, ATenAddInplaceWithAlphaConvertsCorrectly ) {
81
100
const auto graph = R"IR(
82
101
graph(%0 : Tensor, %1 : Tensor):
83
102
%2 : float = prim::Constant[value=7.6]()
@@ -109,6 +128,8 @@ TEST(Converters, ATenSubConvertsCorrectly) {
109
128
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
110
129
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
111
130
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
131
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
132
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
112
133
}
113
134
114
135
TEST (Converters, ATenMulConvertsCorrectly) {
@@ -121,6 +142,8 @@ TEST(Converters, ATenMulConvertsCorrectly) {
121
142
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
122
143
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
123
144
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
145
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
146
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
124
147
}
125
148
126
149
TEST (Converters, ATenMulWithScalarConvertsCorrectly) {
@@ -151,6 +174,8 @@ TEST(Converters, ATenDivConvertsCorrectly) {
151
174
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
152
175
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
153
176
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
177
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
178
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
154
179
}
155
180
156
181
TEST (Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -173,6 +198,8 @@ TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
173
198
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 }, true );
174
199
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 }, true );
175
200
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 }, true );
201
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
202
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
176
203
}
177
204
178
205
TEST (Converters, ATenDivRoundingTruncConvertsCorrectly) {
@@ -186,6 +213,8 @@ TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
186
213
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 }, true );
187
214
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 }, true );
188
215
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 }, true );
216
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
217
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
189
218
}
190
219
191
220
TEST (Converters, ATenDivRoundingNoneConvertsCorrectly) {
@@ -211,6 +240,8 @@ TEST(Converters, ATenPowTensorConvertsCorrectly) {
211
240
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
212
241
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
213
242
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
243
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
244
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
214
245
}
215
246
216
247
TEST (Converters, ATenPowScalarConvertsCorrectly) {
@@ -251,6 +282,8 @@ TEST(Converters, ATenFloorDivideConvertsCorrectly) {
251
282
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
252
283
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
253
284
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
285
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
286
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
254
287
}
255
288
256
289
TEST (Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
0 commit comments