
Commit 0449068

Fix errors in Dynamo test cases and add support for converting int64 indices to int32 in TRT10 tests
1 parent 19049bc commit 0449068

File tree: 3 files changed (+35, -21 lines)


py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 6 additions & 0 deletions
@@ -407,6 +407,12 @@ def scatter(
     index_shape = index.shape
     index_shape_list = list(index.shape)
     if not (isinstance(index, TRTTensor)):
+        if isinstance(index, torch.Tensor):
+            if index.dtype == torch.int64:
+                index = index.to(torch.int32)
+        elif isinstance(index, np.ndarray):
+            if index.dtype == np.int64:
+                index = index.astype(np.int32)
         index = get_trt_tensor(ctx, index, f"_index_tensor")
     dim = get_positive_dim(dim, len(input_shape))
     dynamic_shape = has_dynamic_shape(input.shape)
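Per the commit message, this narrows int64 indices to int32 before get_trt_tensor freezes them into TRT constants for the TRT10 test runs. A minimal standalone sketch of the same narrowing step, using the hypothetical helper name narrow_index (not part of the Torch-TensorRT API):

import numpy as np
import torch

def narrow_index(index):
    # Narrow 64-bit host-side indices (torch or numpy) to 32-bit, as the
    # diff above does just before the index becomes a TRT constant.
    if isinstance(index, torch.Tensor):
        if index.dtype == torch.int64:
            index = index.to(torch.int32)
    elif isinstance(index, np.ndarray):
        if index.dtype == np.int64:
            index = index.astype(np.int32)
    return index

print(narrow_index(torch.tensor([[0, 1, 2, 0]])).dtype)      # torch.int32
print(narrow_index(np.array([0, 1], dtype=np.int64)).dtype)  # int32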

tests/py/dynamo/conversion/harness.py

Lines changed: 13 additions & 3 deletions
@@ -226,6 +226,7 @@ def run_test(
         check_dtype=True,
         use_dynamo_tracer=False,
         enable_passes=False,
+        int32_reqd=False,
     ):
         mod.eval()
         mod = self.generate_graph(
@@ -245,10 +246,19 @@ def run_test(
 
         num_inputs = len(inputs)
         trt_inputs = inputs
+        dtype_to_change = []
+        if int32_reqd:
+            dtype_to_change = [torch.int64, torch.float64]
+        else:
+            dtype_to_change = [
+                torch.float64,
+            ]
         for num_input in range(num_inputs):
             input = inputs[num_input]
-            if input.dtype is torch.float64:
-                dtype_32bit = torch.float32
+            if input.dtype in dtype_to_change:
+                dtype_32bit = (
+                    torch.float32 if (input.dtype == torch.float64) else torch.int32
+                )
             # should we modify graph here to insert clone nodes?
             # ideally not required
             trt_inputs = (
@@ -360,4 +370,4 @@ def run_test_with_dynamic_shape(
         # Since the lowering is based on optimal shape. We need to test with
         # different shape(for ex. max shape) for testing dynamic shape
         inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
-        super().run_test(mod, inputs_max, interp, rtol, atol)
+        super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
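A runnable sketch of the new harness behavior in isolation, assuming plain torch tensors as inputs; narrow_inputs is a hypothetical standalone name, but the branching mirrors the diff: float64 is always narrowed to float32, and int64 is narrowed to int32 only when a test opts in via int32_reqd=True.

import torch

def narrow_inputs(inputs, int32_reqd=False):
    # Mirror of the run_test loop above: pick which 64-bit dtypes to
    # narrow, then map each matching input to its 32-bit counterpart.
    dtype_to_change = [torch.int64, torch.float64] if int32_reqd else [torch.float64]
    trt_inputs = []
    for input in inputs:
        if input.dtype in dtype_to_change:
            dtype_32bit = torch.float32 if input.dtype == torch.float64 else torch.int32
            trt_inputs.append(input.to(dtype_32bit))
        else:
            trt_inputs.append(input)
    return trt_inputs

narrowed = narrow_inputs(
    [torch.zeros(2, dtype=torch.float64), torch.zeros(2, dtype=torch.int64)],
    int32_reqd=True,
)
print([t.dtype for t in narrowed])  # [torch.float32, torch.int32]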

tests/py/dynamo/conversion/test_scatter_aten.py

Lines changed: 16 additions & 18 deletions
@@ -45,10 +45,7 @@ def forward(self, input):
 
         input = torch.zeros(3, 5, dtype=torch.int32)
         inputs = [input]
-        self.run_test(
-            TestModule(),
-            inputs,
-        )
+        self.run_test(TestModule(), inputs, int32_reqd=True)
 
     @parameterized.expand(
         [
@@ -78,10 +75,7 @@ def forward(self, input, index):
 
         input = torch.zeros(3, 5, dtype=torch.int32)
         inputs = [input, index]
-        self.run_test(
-            TestModule(),
-            inputs,
-        )
+        self.run_test(TestModule(), inputs, int32_reqd=True)
 
 
 class TestScatterSrcConverter(DispatchTestCase):
@@ -113,8 +107,18 @@ class TestScatterSrcConverter(DispatchTestCase):
             ),
             # These are special cases where in the harness.py code might need to be changed to input cuda_inputs
            # In that case below two test cases would also require index and src to be on cuda
-            # ("scatter_one_dim_indexOne_constant_src", 1, torch.tensor([[0, 1, 2, 0]]), torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)),
-            # ("scatter_one_dim_indexTwo_constant_src", 1, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32)),
+            (
+                "scatter_one_dim_indexOne_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexTwo_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
+            ),
         ]
     )
     def test_scatter_index_constant(self, _, dim, index, src):
@@ -128,10 +132,7 @@ def forward(self, input):
         input = torch.zeros(3, 5, dtype=torch.int32)
         inputs = [input]
         scatter = TestModule()
-        self.run_test(
-            TestModule(),
-            inputs,
-        )
+        self.run_test(TestModule(), inputs, int32_reqd=True)
 
     @parameterized.expand(
         [
@@ -171,10 +172,7 @@ def forward(self, input, index):
 
         input = torch.zeros(3, 5, dtype=torch.int32)
         inputs = [input, index]
-        self.run_test(
-            TestModule(),
-            inputs,
-        )
+        self.run_test(TestModule(), inputs, int32_reqd=True)
 
 
 if __name__ == "__main__":
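For context, a hedged sketch of how one of these scatter tests might call the harness with the new flag. The TestModule bodies are not shown in this diff, so the module below is illustrative only; aten.scatter.value is one plausible target overload.

import torch
from torch import nn

class ScatterValueModule(nn.Module):
    # Illustrative stand-in for the TestModule classes above: scatter a
    # scalar value into the input along dim 0 at constant indices.
    def forward(self, input):
        index = torch.tensor([[0, 1, 2, 0]])
        return torch.ops.aten.scatter.value(input, 0, index, 1)

# Inside a DispatchTestCase-derived test, int32_reqd=True asks the harness
# to narrow any int64 test inputs to int32 before building the engine:
#     self.run_test(ScatterValueModule(), [torch.zeros(3, 5, dtype=torch.int32)],
#                   int32_reqd=True)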
