Commit c3bdf9f

Correcting rand test cases
1 parent bc873ce commit c3bdf9f

File tree

1 file changed: +37 −13 lines changed


tests/py/dynamo/conversion/test_rand_aten.py

Lines changed: 37 additions & 13 deletions
@@ -1,9 +1,8 @@
 import torch
 import torch.nn as nn
 from parameterized import parameterized
-from torch.testing._internal.common_utils import run_tests
-
-from .harness import DispatchTestCase
+from torch.testing._internal.common_utils import TestCase, run_tests
+import torch_tensorrt
 
 rand_ops = [
     (
@@ -39,43 +38,68 @@
     (
         "randperm_one_case",
         (lambda x: torch.ops.aten.randperm(x)),
-        [1],
+        1,
     ),
     (
         "randperm_two_case",
         (lambda x: torch.ops.aten.randperm(x)),
-        [150],
+        150,
     ),
     (
         "randperm_three_case",
         (lambda x: torch.ops.aten.randperm(x)),
-        [1500],
+        1500,
     ),
 ]
 
 
-class TestRandConverter(DispatchTestCase):
+class TestRandConverter(TestCase):
     @parameterized.expand(
         [
             (
                 rand_op[0],
                 rand_op[1],
+                rand_op[2],
             )
             for rand_op in rand_ops
         ]
     )
     def test_rand(self, _, op, shape_or_input):
         class TestModule(nn.Module):
-            def __init__(self, rand_op):
+            def __init__(self, rand_op, size):
                 super().__init__()
                 self.rand_op = rand_op
+                self.size = size
+
+            def forward(self):
+                return self.rand_op(self.size)
 
-            def forward(self, x):
-                return self.rand_op(x)
+        grid_model = TestModule(op, shape_or_input)
+        # cannot use self.run_test() since it expects input in form of tensor
+
+        # self.run_test(grid_model, None)
+        fx_graph = torch.fx.symbolic_trace(grid_model)
+        torch._dynamo.reset()
 
-        inputs = [shape_or_input]
-        grid_model = TestModule(op)
-        self.run_test(grid_model, inputs)
+        optimized_model = torch_tensorrt.compile(fx_graph,
+            "torch_compile",
+            None,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model().detach().cpu()
+        torch_model_results = fx_graph().detach().cpu()
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            4,
+            f"TRT outputs don't match with the original model.",
+        )
 
 
 if __name__ == "__main__":
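
For readers skimming the patch: the reworked test no longer goes through DispatchTestCase.run_test (which expects tensor inputs), but instead traces the module with torch.fx, compiles the traced graph through the "torch_compile" path, and compares the TensorRT output against the eager FX graph. Below is a minimal standalone sketch of that flow, reconstructed from the diff above; the module name RandPermModule and the size 150 are illustrative, while the compile arguments mirror the patch.

import torch
import torch_tensorrt


class RandPermModule(torch.nn.Module):
    """Mirrors the patched TestModule: op and size are baked in, forward() takes no inputs."""

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self):
        return torch.ops.aten.randperm(self.size)


model = RandPermModule(150)

# Trace to an FX GraphModule so it can be handed to torch_tensorrt.compile
# without example tensor inputs (the harness' run_test() needs tensors).
fx_graph = torch.fx.symbolic_trace(model)
torch._dynamo.reset()

# Same call shape as in the patch: "torch_compile" IR, no inputs.
trt_model = torch_tensorrt.compile(
    fx_graph,
    "torch_compile",
    None,
    min_block_size=1,
    pass_through_build_failures=True,
    truncate_long_and_double=True,
)

# The test then compares the TRT output against the eager FX graph and
# asserts the max absolute difference is ~0 (to 4 decimal places).
trt_out = trt_model().detach().cpu()
eager_out = fx_graph().detach().cpu()
max_diff = float(torch.max(torch.abs(trt_out - eager_out)))
print(max_diff)

In the commit itself this comparison is wrapped in assertAlmostEqual(max_diff, 0, 4, ...) and parameterized over every entry in rand_ops.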
