
Commit e43833d

switch from fx.symbolic_trace to dynamo_trace for converter test part-1 (#3261)
1 parent 3ecd5aa commit e43833d

15 files changed, +374 -282 lines changed

tests/py/dynamo/conversion/harness.py

Lines changed: 124 additions & 20 deletions
@@ -14,10 +14,13 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args
 
 # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
+from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
     post_lowering,
@@ -29,6 +32,77 @@
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
+# this method is only used in our converter tests to infer the module output dtypes via dummy inference,
+# because fx.symbolic_trace does not populate the meta['val'] info in the nodes
+# TODO: lan to remove this once our converter tests are moved from fx.symbolic_trace to dynamo trace
+def infer_module_output_dtypes_for_test(
+    module: torch.fx.GraphModule,
+    inputs: Sequence[Input],
+    device: Device,
+    kwarg_inputs: Optional[dict[str, Any]] = None,
+    truncate_double: bool = False,
+) -> List[dtype]:
+    """
+    This function performs model inference to determine the output dtypes
+    and truncates them accordingly. inputs can be either arg_inputs or a flattened input list.
+    If it is a flattened list, kwarg_inputs should be None, as it is already included in the flattened input.
+    """
+    # TODO: We can also determine output dtypes from the module.graph based on node metadata.
+    # However, our converter tests use fx.symbolic_trace, which sometimes does not provide metadata,
+    # so we stick to the model inference approach currently.
+    with unset_fake_temporarily():
+        # Get the device on which the model exists
+        # For large models, this can be done on CPU to save GPU memory allocation for TRT.
+        device = get_model_device(module)
+        torch_inputs = get_torch_inputs(inputs, device)
+        if kwarg_inputs is None:
+            kwarg_inputs = {}
+        torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
+        module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
+        if not isinstance(module_outputs, (list, tuple)):
+            module_outputs = [module_outputs]
+
+    # Int64 outputs can sometimes be generated from within other operators
+    # such as aten.sum - such outputs can be truncated
+    output_dtypes = []
+    for output in module_outputs:
+        output_ = output
+        # We don't need to check if output is nested here because the input module will be flattened
+        if not isinstance(output, torch.Tensor):
+            if isinstance(output, str):
+                raise ValueError(
+                    f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
+                )
+            else:
+                output_ = torch.tensor(output)
+
+        if truncate_double and output_.dtype == dtype.float64:
+            output_dtypes.append(dtype.float32)
+        else:
+            output_dtypes.append(dtype._from(output_.dtype))
+
+    return output_dtypes
+
+
+# this is to enable the dynamo tracer in the converter test files batch by batch
+def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool:
+    # if a converter test explicitly sets the use_dynamo_tracer field, honor it
+    if use_dynamo_tracer is not None and isinstance(use_dynamo_tracer, bool):
+        return use_dynamo_tracer
+    # otherwise, decide based on the calling test file's name
+    import inspect
+    import os
+    import re
+
+    filename = os.path.basename(inspect.stack()[2].filename)
+    # enable converter test files whose names start with test_a*.py to use the dynamo tracer
+    pattern = re.compile("^test_([a])+")
+    if pattern.match(filename):
+        return True
+    else:
+        return False
+
+
 # this method is only used in our converter tests to infer the module output dtypes via dummy inference,
 # because fx.symbolic_trace does not populate the meta['val'] info in the nodes
 # TODO: lan to remove this once our converter tests are moved from fx.symbolic_trace to dynamo trace
@@ -277,14 +351,26 @@ def generate_graph(
         enable_passes: bool,
         propagate_shapes: bool = False,
         settings: CompilationSettings = CompilationSettings(),
+        torch_export_dynamic_shapes: Optional[Any] = None,
     ):
         mod = mod.eval()
         if use_dynamo_tracer:
-            exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs))
-            exported_program = pre_export_lowering(exported_program, settings)
-            exported_program = exported_program.run_decompositions(
-                get_decompositions(False)
+            if torch_export_dynamic_shapes is None:
+                torch_export_dynamic_shapes = get_dynamic_shapes_args(
+                    mod, original_inputs
+                )
+            device = default_device()
+            torch_export_inputs = get_torch_inputs(original_inputs, device)
+            exported_program = torch.export.export(
+                mod,
+                tuple(torch_export_inputs),
+                dynamic_shapes=torch_export_dynamic_shapes,
             )
+            if enable_passes:
+                exported_program = pre_export_lowering(exported_program, settings)
+                exported_program = exported_program.run_decompositions(
+                    get_decompositions(False)
+                )
             fx_module = exported_program.module()
         else:
             fx_module = torch.fx.symbolic_trace(mod)
@@ -313,13 +399,15 @@ def run_test(
         atol=ATOL,
         precision=dtype.f32,
         check_dtype=True,
-        use_dynamo_tracer=False,
+        use_dynamo_tracer=None,
         enable_passes=False,
         propagate_shapes=False,
         int32_reqd=False,
         make_refittable=False,
     ):
-
+        # TODO: lan to remove this and set use_dynamo_tracer to True by default
+        # once all the converter test files are moved to use_dynamo_tracer
+        use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer)
         # Previous instance of the interpreter auto-casted 64-bit inputs
         # We replicate this behavior here
         compilation_settings = CompilationSettings(
@@ -366,12 +454,18 @@ def run_test(
 
         output_dtypes = None
         if check_dtype:
-            output_dtypes = infer_module_output_dtypes_for_test(
-                mod,
-                input_specs,
-                compilation_settings.device,
-                truncate_double=compilation_settings.truncate_double,
-            )
+            if use_dynamo_tracer:
+                output_dtypes = infer_module_output_dtypes(
+                    mod,
+                    truncate_double=compilation_settings.truncate_double,
+                )
+            else:
+                output_dtypes = infer_module_output_dtypes_for_test(
+                    mod,
+                    input_specs,
+                    compilation_settings.device,
+                    truncate_double=compilation_settings.truncate_double,
+                )
 
         _LOGGER.debug(f"Compilation settings: {compilation_settings}")
         _LOGGER.debug(f"Inputs: {input_specs}")
@@ -441,37 +535,47 @@ def run_test_with_dynamic_shape(
         rtol=RTOL,
         atol=ATOL,
         output_dtypes=None,
-        use_dynamo_tracer=False,
+        use_dynamo_tracer=None,
         enable_passes=False,
         use_example_tensors=True,
         pyt_inputs=None,
         propagate_shapes=False,
         check_dtype=True,
         make_refittable=False,
+        torch_export_dynamic_shapes=None,
     ):
+        # TODO: lan to remove this and set use_dynamo_tracer to True by default
+        # once all the converter test files are moved to use_dynamo_tracer
+        use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer)
 
         # Previous instance of the interpreter auto-casted 64-bit inputs
        # We replicate this behavior here
         compilation_settings = CompilationSettings(
             truncate_double=True, make_refittable=make_refittable
         )
-
         mod = self.generate_graph(
             mod,
             input_specs,
             use_dynamo_tracer=use_dynamo_tracer,
             enable_passes=enable_passes,
             propagate_shapes=propagate_shapes,
             settings=compilation_settings,
+            torch_export_dynamic_shapes=torch_export_dynamic_shapes,
         )
 
         if check_dtype:
-            output_dtypes = infer_module_output_dtypes_for_test(
-                mod,
-                input_specs,
-                compilation_settings.device,
-                truncate_double=compilation_settings.truncate_double,
-            )
+            if use_dynamo_tracer:
+                output_dtypes = infer_module_output_dtypes(
+                    mod,
+                    truncate_double=compilation_settings.truncate_double,
+                )
+            else:
+                output_dtypes = infer_module_output_dtypes_for_test(
+                    mod,
+                    input_specs,
+                    compilation_settings.device,
+                    truncate_double=compilation_settings.truncate_double,
+                )
 
         interp = TRTInterpreter(
             mod,
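
With the dynamo tracer enabled, generate_graph now goes through torch.export.export with explicit dynamic-shape arguments instead of torch_tensorrt.dynamo.trace, and the exported graph carries fake-tensor metadata that infer_module_output_dtypes can read directly. A minimal sketch of that flow outside the harness (TinyModel and the hand-built dynamic_shapes dict are illustrative stand-ins for what get_dynamic_shapes_args derives from the Input specs):

```python
# Illustrative sketch only; TinyModel and the dynamic_shapes dict are
# hypothetical stand-ins, not part of this commit.
import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.acos(x)


x = torch.randn(2, 2, 3)
# Mark dim 0 as dynamic, mirroring a min/opt/max Input spec such as
# (1, 1, 1) / (2, 2, 3) / (3, 3, 3) restricted to the first dimension.
dyn = {"x": {0: torch.export.Dim("d0", min=2, max=8)}}

exported = torch.export.export(TinyModel(), (x,), dynamic_shapes=dyn)
gm = exported.module()

# Unlike fx.symbolic_trace, export populates meta["val"] with fake tensors,
# so output dtypes can be read off the graph without dummy inference.
output_node = next(n for n in gm.graph.nodes if n.op == "output")
print([arg.meta["val"].dtype for arg in output_node.args[0]])  # [torch.float32]
```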

tests/py/dynamo/conversion/test_acos_aten.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def forward(self, input):
         (
             "3d_dim_dtype_float",
             (1, 1, 1),
-            (1, 2, 3),
+            (2, 2, 3),
             (3, 3, 3),
             torch.float,
             torch.float,
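
The opt-shape bump from (1, 2, 3) to (2, 2, 3) here, and the matching edits in the acosh/any/asin/asinh/atan2 tests below, move dynamic dimensions away from size 1. A plausible reason, stated as an assumption about torch.export behavior rather than anything this commit documents: export specializes dimensions whose example size is 0 or 1, so a size-1 example cannot be marked dynamic. A hedged illustration:

```python
# Hedged illustration (assumes torch.export's 0/1 specialization): marking a
# size-1 example dimension as dynamic typically fails, which would explain why
# these tests now avoid 1s in the opt shapes of dynamic dimensions.
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return torch.acos(x)


dyn = {"x": {0: torch.export.Dim("d0", min=1, max=3)}}
try:
    torch.export.export(M(), (torch.randn(1, 2, 3),), dynamic_shapes=dyn)
except Exception as err:  # expected: a constraint-violation style UserError
    print(type(err).__name__)
```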

tests/py/dynamo/conversion/test_acosh_aten.py

Lines changed: 2 additions & 2 deletions
@@ -58,15 +58,15 @@ def forward(self, input):
         (
             "3d_dim_dtype_float",
             (1, 1, 1),
-            (1, 2, 3),
+            (2, 2, 3),
             (3, 3, 3),
             torch.float,
             torch.float,
         ),
         (
             "3d_dim_dtype_int32",
             (1, 1, 1),
-            (1, 2, 4),
+            (2, 2, 4),
             (2, 3, 5),
             torch.int32,
             torch.float,

tests/py/dynamo/conversion/test_any.py

Lines changed: 4 additions & 4 deletions
@@ -191,7 +191,7 @@ class TestAnyConverterDynamic(DispatchTestCase):
         (
             "3d_dynamic_float",
             (2, 1, 1),
-            (2, 2, 1),
+            (2, 2, 2),
             (3, 2, 4),
             torch.float,
         ),
@@ -234,7 +234,7 @@ def forward(self, x):
         (
             "3d_dynamic_dim_float",
             (2, 1, 1),
-            (2, 2, 1),
+            (2, 2, 2),
             (3, 2, 4),
             torch.float,
             2,
@@ -252,7 +252,7 @@ def forward(self, x):
         (
             "3d_dynamic_dim_bool",
             (2, 1, 1),
-            (2, 2, 1),
+            (2, 2, 2),
             (3, 2, 4),
             torch.bool,
             0,
@@ -285,7 +285,7 @@ def forward(self, x):
         (
             "3d_dynamic_dims_float",
             (2, 1, 1),
-            (2, 2, 1),
+            (2, 2, 2),
             (3, 2, 4),
             torch.float,
             [1, 2],

tests/py/dynamo/conversion/test_arange_aten.py

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ def forward(self, end_tensor):
             use_example_tensors=False,
             check_dtype=False,
             pyt_inputs=[pyt_input],
+            use_dynamo_tracer=False,
         )

tests/py/dynamo/conversion/test_asin_aten.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def forward(self, input):
         (
             "3d_dim_dtype_float",
             (1, 1, 1),
-            (1, 2, 3),
+            (2, 2, 3),
             (3, 3, 3),
             torch.float,
             torch.float,

tests/py/dynamo/conversion/test_asinh_aten.py

Lines changed: 2 additions & 2 deletions
@@ -58,15 +58,15 @@ def forward(self, input):
         (
             "3d_dim_dtype_float",
             (1, 1, 1),
-            (1, 2, 3),
+            (2, 2, 3),
             (3, 3, 3),
             torch.float,
             torch.float,
         ),
         (
             "3d_dim_dtype_int32",
             (1, 1, 1),
-            (1, 2, 4),
+            (2, 2, 4),
             (2, 3, 5),
             torch.int32,
             torch.float,

tests/py/dynamo/conversion/test_atan2_aten.py

Lines changed: 17 additions & 5 deletions
@@ -1,3 +1,5 @@
+import unittest
+
 import torch
 import torch.nn as nn
 from parameterized import parameterized
@@ -141,15 +143,15 @@ def forward(self, lhs_val, rhs_val):
         (
             "3d_dim_dtype_float",
             (1, 1, 1),
-            (1, 2, 3),
+            (2, 2, 3),
             (3, 3, 3),
             torch.float,
             torch.float,
         ),
         (
             "3d_dim_dtype_int32",
             (1, 1, 1),
-            (1, 2, 4),
+            (2, 2, 4),
             (2, 3, 5),
             torch.int32,
             torch.float,
@@ -182,10 +184,17 @@ def forward(self, lhs_val, rhs_val):
         )
 
 
+# torch.ops.aten.atan2.out will be decomposed/partitioned by the dynamo tracer into core aten ops that
+# torch_tensorrt supports in run_on_acc and unsupported ops in run_on_gpu; it works via the
+# torch_tensorrt.dynamo.compile workflow but is not valid for our converter test framework, so skip it here.
+@unittest.skip("skip torch.ops.aten.atan2.out converter test")
 class TestAtan2OutConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ((10,), (5,), torch.float),
+            # dynamo trace does not allow the output to have a different shape
+            # raise Unsupported(msg, case_name=case_name)
+            # torch._dynamo.exc.Unsupported: out variants with resizing on graph inputs
+            ((5,), (5,), torch.float),
             ((10,), (10,), torch.float),
         ]
     )
@@ -220,7 +229,7 @@ def forward(self, lhs_val, rhs_val, out):
         (
             "3d_dim_dtype_float",
             (1, 1, 1),
-            (1, 2, 3),
+            (2, 2, 3),
             (3, 3, 3),
             torch.float,
             torch.float,
@@ -255,7 +264,10 @@ def forward(self, lhs_val, rhs_val, out):
             ),
         ]
         self.run_test_with_dynamic_shape(
-            atan2(), input_specs, output_dtypes=[output_type]
+            atan2(),
+            input_specs,
+            output_dtypes=[output_type],
+            use_dynamo_tracer=False,
         )
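
To make the skipped out-variant cases concrete, here is a hypothetical repro (module and shapes invented for illustration) of the dynamo restriction quoted in the comments above:

```python
# Hypothetical repro of "out variants with resizing on graph inputs": an `out`
# tensor whose shape differs from the result forces a resize during tracing,
# which dynamo-based export rejects.
import torch


class Atan2Out(torch.nn.Module):
    def forward(self, lhs, rhs, out):
        return torch.ops.aten.atan2.out(lhs, rhs, out=out)


lhs, rhs = torch.randn(10), torch.randn(10)
out = torch.empty(5)  # mismatched shape; torch.empty(10) avoids this error
try:
    torch.export.export(Atan2Out(), (lhs, rhs, out))
except Exception as err:
    print(type(err).__name__, err)  # expect torch._dynamo.exc.Unsupported
```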
