Commit 486c0b2

changing the test to compare size instead of elements
1 parent e099dbb commit 486c0b2

File tree

3 files changed: +99 -37 lines changed


py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 3 additions & 3 deletions
@@ -73,7 +73,7 @@ def aten_ops_rand(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return np.random.rand(*args)
+    return np.random.rand(*args[0])


 @dynamo_tensorrt_converter(
@@ -86,7 +86,7 @@ def aten_ops_randn(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return np.random.randn(*args)
+    return np.random.randn(*args[0])


 def randperm_validator(randperm_node: Node) -> bool:
@@ -117,4 +117,4 @@ def aten_ops_randperm(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return np.random.permutation(*args)
+    return np.random.permutation(args[0])
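To see why the evaluators now unpack args[0], note that they receive the ATen node's positional arguments as a tuple whose first element is the shape list (e.g. ([2, 3],) for aten.rand([2, 3])). A minimal standalone sketch of the difference, assuming that argument layout:

import numpy as np

# Hypothetical argument tuple as an evaluator might receive it for aten.rand([2, 3]):
# the single positional argument is the shape list itself.
args = ([2, 3],)

# np.random.rand(*args) would expand to np.random.rand([2, 3]) and raise
# "TypeError: 'list' object cannot be interpreted as an integer".
out = np.random.rand(*args[0])  # expands to np.random.rand(2, 3)
print(out.shape)                # (2, 3)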

tests/py/dynamo/conversion/harness.py

Lines changed: 34 additions & 1 deletion
@@ -146,7 +146,6 @@ def run_test_custom_compare_results(
             res_trt = trt_mod(*cuda_inputs).cpu()
             res_cpu = mod(*inputs)
             assert len(res_trt) == len(res_cpu)
-            assert len(res_cpu) == len(comparators)
             for output_trt, output_cpu, comparator in zip(
                 res_trt, res_cpu, comparators
             ):
@@ -252,6 +251,40 @@ def run_test(
             check_dtype,
         )

+    def run_test_comparator(
+        self,
+        mod,
+        inputs,
+        expected_ops,
+        comparators: List[Tuple[Callable, List]],
+        precision=torch.float32,
+        output_dtypes=None,
+        use_dynamo_tracer=False,
+        enable_passes=False,
+    ):
+        mod.eval()
+        mod = self.generate_graph(
+            mod,
+            inputs,
+            use_dynamo_tracer=use_dynamo_tracer,
+            enable_passes=enable_passes,
+        )
+        # Previous instance of the interpreter auto-casted 64-bit inputs
+        # We replicate this behavior here
+        compilation_settings = CompilationSettings(
+            precision=precision, truncate_long_and_double=True
+        )
+
+        interp = TRTInterpreter(
+            mod,
+            Input.from_tensors(inputs),
+            output_dtypes=output_dtypes,
+            compilation_settings=compilation_settings,
+        )
+        super().run_test_custom_compare_results(
+            mod, inputs, expected_ops, interp, comparators
+        )
+
     def run_test_with_dynamic_shape(
         self,
         mod,
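For orientation, each entry in comparators is a (callable, extra_args) pair, and the zip(res_trt, res_cpu, comparators) loop above pairs each TRT output with its CPU reference and its comparator. A minimal standalone sketch of that convention (the tensors and the unpacking here are illustrative, not the harness API):

import torch

# Stand-ins for res_trt / res_cpu as produced by the TRT module and the eager module.
res_trt = [torch.zeros(2, 3)]
res_cpu = [torch.zeros(2, 3)]

# (callable, extra_args): compare shapes, and dtypes only when requested.
comparator_shape = lambda x, y, check_dtype: x.shape == y.shape and (
    x.dtype == y.dtype if check_dtype else True
)
comparators = [(comparator_shape, [True])]

for output_trt, output_cpu, (fn, extra_args) in zip(res_trt, res_cpu, comparators):
    assert fn(output_trt, output_cpu, *extra_args)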

tests/py/dynamo/conversion/test_rand_aten.py

Lines changed: 62 additions & 33 deletions
@@ -4,6 +4,8 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests

+from .harness import DispatchTestCase
+
 rand_ops = [
     (
         "rand_one_dimension",
@@ -13,7 +15,7 @@
     (
         "rand_two_dimension",
         (lambda shape: torch.ops.aten.rand(shape)),
-        [2, 3],
+        [1, 2],
     ),
     (
         "rand_three_dimension",
@@ -35,25 +37,29 @@
         (lambda shape: torch.ops.aten.randn(shape)),
         [2, 3, 4],
     ),
+]
+
+
+rand_perm_ops = [
     (
         "randperm_one_case",
         (lambda x: torch.ops.aten.randperm(x)),
-        1,
+        [1],
     ),
     (
         "randperm_two_case",
         (lambda x: torch.ops.aten.randperm(x)),
-        150,
+        [150],
     ),
     (
         "randperm_three_case",
         (lambda x: torch.ops.aten.randperm(x)),
-        1500,
+        [1500],
     ),
 ]


-class TestRandConverter(TestCase):
+class TestRandConverter(DispatchTestCase):
     @parameterized.expand(
         [
             (
@@ -64,41 +70,64 @@ class TestRandConverter(TestCase):
             for rand_op in rand_ops
         ]
     )
-    def test_rand(self, _, op, shape_or_input):
+    def test_rand(self, name, op, shape_or_input):
         class TestModule(nn.Module):
-            def __init__(self, rand_op, size):
+            def __init__(self):
                 super().__init__()
-                self.rand_op = rand_op
-                self.size = size

-            def forward(self):
-                return self.rand_op(self.size)
+            def forward(self, x):
+                shape_or_input[0] = x.shape[0]
+                return op(shape_or_input)

-        rand_model = TestModule(op, shape_or_input)
-        # cannot use self.run_test() since it expects input in form of tensor
+        rand_model = TestModule()

-        fx_graph = torch.fx.symbolic_trace(grid_model)
-        torch._dynamo.reset()
-
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "torch_compile",
-            None,
-            min_block_size=1,
-            pass_through_build_failures=True,
-            truncate_long_and_double=True,
-            debug=True,
+        inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]
+        comparator_shape = lambda x, y, check_dtype: x.shape == y.shape and (
+            x.dtype == y.dtype if check_dtype else True
+        )
+        expected_ops = []
+        self.run_test_comparator(
+            rand_model,
+            inputs,
+            expected_ops,
+            [(comparator_shape, [True])],
+            use_dynamo_tracer=True,
         )
-        optimized_model_results = optimized_model().detach().cpu()
-        torch_model_results = fx_graph().detach().cpu()
-        max_diff = float(
-            torch.max(torch.abs(optimized_model_results - torch_model_results))
+
+    @parameterized.expand(
+        [
+            (
+                rand_op[0],
+                rand_op[1],
+                rand_op[2],
+            )
+            for rand_op in rand_perm_ops
+        ]
+    )
+    def test_rand(self, name, op, shape_or_input):
+        class TestModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                shape_or_input[0] = x.shape[0]
+                return op(shape_or_input[0])
+
+        rand_model = TestModule()
+        # cannot use self.run_test() since it expects input in form of tensor
+
+        inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]
+        comparator_shape = lambda x, y, check_dtype: x.shape == y.shape and (
+            x.dtype == y.dtype if check_dtype else True
         )
-        self.assertAlmostEqual(
-            max_diff,
-            0,
-            4,
-            f"TRT outputs don't match with the original model.",
+        expected_ops = []
+        # TRT-np returns int32 while torch returns float32
+        self.run_test_comparator(
+            rand_model,
+            inputs,
+            expected_ops,
+            [(comparator_shape, [False])],
+            use_dynamo_tracer=True,
         )
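As a quick illustration of why the tests now compare size instead of elements (a standalone sketch, not part of the commit): two independent random draws virtually never agree elementwise, so an elementwise max-diff assertion like the removed one is not meaningful for rand/randn/randperm, whereas a shape (and optional dtype) check is stable.

import torch

a = torch.ops.aten.rand([2, 3])
b = torch.ops.aten.rand([2, 3])

# Elementwise comparison of two independent random tensors almost always fails ...
print(torch.max(torch.abs(a - b)))  # almost certainly > 0

# ... while the shape-based comparator used in the tests above is deterministic.
comparator_shape = lambda x, y, check_dtype: x.shape == y.shape and (
    x.dtype == y.dtype if check_dtype else True
)
print(comparator_shape(a, b, True))  # True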
