14 | 14 | from torch_tensorrt._Device import Device
15 | 15 | from torch_tensorrt._enums import dtype
16 | 16 | from torch_tensorrt.dynamo import _defaults
   | 17 | +from torch_tensorrt.dynamo._defaults import default_device
17 | 18 | from torch_tensorrt.dynamo._settings import CompilationSettings
   | 19 | +from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args
18 | 20 |
19 | 21 | # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
20 | 22 | from torch_tensorrt.dynamo.conversion import TRTInterpreter
   | 23 | +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
21 | 24 | from torch_tensorrt.dynamo.lowering import (
22 | 25 |     get_decompositions,
23 | 26 |     post_lowering,

29 | 32 | _LOGGER: logging.Logger = logging.getLogger(__name__)
30 | 33 |
31 | 34 |
   | 35 | +# This helper is only used in our converter tests to infer the module output dtypes via dummy inference,
   | 36 | +# because fx.symbolic_trace does not populate the meta['val'] info on graph nodes.
   | 37 | +# TODO: lan to remove this once our converter tests are moved from fx.symbolic_trace to dynamo trace
   | 38 | +def infer_module_output_dtypes_for_test(
   | 39 | +    module: torch.fx.GraphModule,
   | 40 | +    inputs: Sequence[Input],
   | 41 | +    device: Device,
   | 42 | +    kwarg_inputs: Optional[dict[str, Any]] = None,
   | 43 | +    truncate_double: bool = False,
   | 44 | +) -> List[dtype]:
   | 45 | +    """
   | 46 | +    Performs dummy model inference to determine the output dtypes
   | 47 | +    and truncates them accordingly. `inputs` can be either arg_inputs or a flattened input list.
   | 48 | +    If it is a flattened list, kwarg_inputs should be None, as the kwargs are already included in the flattened inputs.
   | 49 | +    """
   | 50 | +    # TODO: We can also determine output dtypes from the module.graph based on node metadata.
   | 51 | +    # However, our converter tests use fx.symbolic_trace, which sometimes does not provide metadata,
   | 52 | +    # so we stick to the model inference approach currently.
   | 53 | +    with unset_fake_temporarily():
   | 54 | +        # Get the device on which the model exists.
   | 55 | +        # For large models, this can be done on CPU to save GPU memory allocation for TRT.
   | 56 | +        device = get_model_device(module)
   | 57 | +        torch_inputs = get_torch_inputs(inputs, device)
   | 58 | +        if kwarg_inputs is None:
   | 59 | +            kwarg_inputs = {}
   | 60 | +        torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
   | 61 | +        module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
   | 62 | +        if not isinstance(module_outputs, (list, tuple)):
   | 63 | +            module_outputs = [module_outputs]
   | 64 | +
   | 65 | +        # Int64 outputs can sometimes be generated from within other operators,
   | 66 | +        # such as aten.sum - such outputs can be truncated
   | 67 | +        output_dtypes = []
   | 68 | +        for output in module_outputs:
   | 69 | +            output_ = output
   | 70 | +            # We don't need to check if output is nested here because the input module will be flattened
   | 71 | +            if not isinstance(output, torch.Tensor):
   | 72 | +                if isinstance(output, str):
   | 73 | +                    raise ValueError(
   | 74 | +                        f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
   | 75 | +                    )
   | 76 | +                else:
   | 77 | +                    output_ = torch.tensor(output)
   | 78 | +
   | 79 | +            if truncate_double and output_.dtype == dtype.float64:
   | 80 | +                output_dtypes.append(dtype.float32)
   | 81 | +            else:
   | 82 | +                output_dtypes.append(dtype._from(output_.dtype))
   | 83 | +
   | 84 | +    return output_dtypes
   | 85 | +
   | 86 | +
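A minimal usage sketch of this helper (not part of the diff; the module and input names are illustrative, and `device` is assumed to be a valid torch_tensorrt Device): it runs the traced module once and collects one dtype per flattened output.

    import torch
    import torch.fx
    from torch_tensorrt import Input

    class ToyModule(torch.nn.Module):
        def forward(self, x):
            # Two float32 outputs: a reduction and an elementwise product
            return x.sum(dim=1), x * 2.0

    traced = torch.fx.symbolic_trace(ToyModule())
    inputs = [Input(shape=(2, 3), dtype=torch.float32)]
    # Dummy inference yields one dtype per flattened output, e.g.:
    # infer_module_output_dtypes_for_test(traced, inputs, device, truncate_double=True)
    # -> [dtype.f32, dtype.f32]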
   | 87 | +# This is to enable the dynamo tracer in the converter test files batch by batch
   | 88 | +def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool:
   | 89 | +    # If a converter test explicitly sets the use_dynamo_tracer field, honor it
   | 90 | +    if use_dynamo_tracer is not None and isinstance(use_dynamo_tracer, bool):
   | 91 | +        return use_dynamo_tracer
   | 92 | +    # Otherwise, decide based on the name of the calling test file
   | 93 | +    import inspect
   | 94 | +    import os
   | 95 | +    import re
   | 96 | +
   | 97 | +    filename = os.path.basename(inspect.stack()[2].filename)
   | 98 | +    # Enable the dynamo tracer for converter test files whose names start with test_a
   | 99 | +    pattern = re.compile("^test_([a])+")
   | 100 | +    if pattern.match(filename):
   | 101 | +        return True
   | 102 | +    else:
   | 103 | +        return False
   | 104 | +
   | 105 | +
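An illustrative check of the filename gate (the test file names below are hypothetical; the regex is taken verbatim from the diff):

    import re

    pattern = re.compile("^test_([a])+")
    assert pattern.match("test_acos_aten.py") is not None  # opts in to dynamo tracing
    assert pattern.match("test_sum_aten.py") is None       # stays on fx.symbolic_trace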
32 | 106 | # This helper is only used in our converter tests to infer the module output dtypes via dummy inference,
33 | 107 | # because fx.symbolic_trace does not populate the meta['val'] info on graph nodes.
34 | 108 | # TODO: lan to remove this once our converter tests are moved from fx.symbolic_trace to dynamo trace
@@ -277,14 +351,26 @@ def generate_graph(

277 | 351 |         enable_passes: bool,
278 | 352 |         propagate_shapes: bool = False,
279 | 353 |         settings: CompilationSettings = CompilationSettings(),
    | 354 | +        torch_export_dynamic_shapes: Optional[Any] = None,
280 | 355 |     ):
281 | 356 |         mod = mod.eval()
282 | 357 |         if use_dynamo_tracer:
283 |     | -            exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs))
284 |     | -            exported_program = pre_export_lowering(exported_program, settings)
285 |     | -            exported_program = exported_program.run_decompositions(
286 |     | -                get_decompositions(False)
    | 358 | +            if torch_export_dynamic_shapes is None:
    | 359 | +                torch_export_dynamic_shapes = get_dynamic_shapes_args(
    | 360 | +                    mod, original_inputs
    | 361 | +                )
    | 362 | +            device = default_device()
    | 363 | +            torch_export_inputs = get_torch_inputs(original_inputs, device)
    | 364 | +            exported_program = torch.export.export(
    | 365 | +                mod,
    | 366 | +                tuple(torch_export_inputs),
    | 367 | +                dynamic_shapes=torch_export_dynamic_shapes,
287 | 368 |             )
    | 369 | +            if enable_passes:
    | 370 | +                exported_program = pre_export_lowering(exported_program, settings)
    | 371 | +                exported_program = exported_program.run_decompositions(
    | 372 | +                    get_decompositions(False)
    | 373 | +                )
288 | 374 |             fx_module = exported_program.module()
289 | 375 |         else:
290 | 376 |             fx_module = torch.fx.symbolic_trace(mod)
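For context, a minimal sketch of what the new tracing path does (toy module; names are illustrative): torch.export.export produces an ExportedProgram whose graph nodes carry meta['val'] FakeTensor metadata, unlike fx.symbolic_trace.

    import torch

    class ToyModule(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    batch = torch.export.Dim("batch", min=2, max=8)
    ep = torch.export.export(
        ToyModule(),
        (torch.randn(2, 3),),
        dynamic_shapes={"x": {0: batch}},  # dim 0 is dynamic in [2, 8]
    )
    fx_module = ep.module()  # nodes now carry meta['val'] shape/dtype info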
@@ -313,13 +399,15 @@ def run_test(

313 | 399 |         atol=ATOL,
314 | 400 |         precision=dtype.f32,
315 | 401 |         check_dtype=True,
316 |     | -        use_dynamo_tracer=False,
    | 402 | +        use_dynamo_tracer=None,
317 | 403 |         enable_passes=False,
318 | 404 |         propagate_shapes=False,
319 | 405 |         int32_reqd=False,
320 | 406 |         make_refittable=False,
321 | 407 |     ):
322 |     | -
    | 408 | +        # TODO: lan to remove this and set use_dynamo_tracer to True by default
    | 409 | +        # once all the converter test files are moved to use_dynamo_tracer
    | 410 | +        use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer)
323 | 411 |         # Previous instance of the interpreter auto-casted 64-bit inputs
324 | 412 |         # We replicate this behavior here
325 | 413 |         compilation_settings = CompilationSettings(
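The flag is now effectively tri-state, as a quick sketch of the behavior (grounded in get_use_dynamo_tracer above) shows: explicit booleans are honored, while None defers to the filename gate.

    assert get_use_dynamo_tracer(True) is True    # explicit opt-in
    assert get_use_dynamo_tracer(False) is False  # explicit opt-out
    # get_use_dynamo_tracer(None) inspects the calling test file's name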
@@ -366,12 +454,18 @@ def run_test(

366 | 454 |
367 | 455 |         output_dtypes = None
368 | 456 |         if check_dtype:
369 |     | -            output_dtypes = infer_module_output_dtypes_for_test(
370 |     | -                mod,
371 |     | -                input_specs,
372 |     | -                compilation_settings.device,
373 |     | -                truncate_double=compilation_settings.truncate_double,
374 |     | -            )
    | 457 | +            if use_dynamo_tracer:
    | 458 | +                output_dtypes = infer_module_output_dtypes(
    | 459 | +                    mod,
    | 460 | +                    truncate_double=compilation_settings.truncate_double,
    | 461 | +                )
    | 462 | +            else:
    | 463 | +                output_dtypes = infer_module_output_dtypes_for_test(
    | 464 | +                    mod,
    | 465 | +                    input_specs,
    | 466 | +                    compilation_settings.device,
    | 467 | +                    truncate_double=compilation_settings.truncate_double,
    | 468 | +                )
375 | 469 |
376 | 470 |         _LOGGER.debug(f"Compilation settings: {compilation_settings}")
377 | 471 |         _LOGGER.debug(f"Inputs: {input_specs}")
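The branch exists because a dynamo-traced graph already carries FakeTensor metadata, so infer_module_output_dtypes can read output dtypes without running the model. A simplified sketch of that metadata-based approach (assuming the graph came from dynamo/torch.export, so meta['val'] is present; this is not the library's exact implementation):

    import torch

    def dtypes_from_metadata(gm: torch.fx.GraphModule) -> list:
        # The output node's args[0] holds the returned values; each value
        # carries a FakeTensor in meta['val'] when traced via dynamo.
        output_node = next(n for n in gm.graph.nodes if n.op == "output")
        outputs = output_node.args[0]
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs,)
        return [o.meta["val"].dtype for o in outputs]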
@@ -441,37 +535,47 @@ def run_test_with_dynamic_shape(

441 | 535 |         rtol=RTOL,
442 | 536 |         atol=ATOL,
443 | 537 |         output_dtypes=None,
444 |     | -        use_dynamo_tracer=False,
    | 538 | +        use_dynamo_tracer=None,
445 | 539 |         enable_passes=False,
446 | 540 |         use_example_tensors=True,
447 | 541 |         pyt_inputs=None,
448 | 542 |         propagate_shapes=False,
449 | 543 |         check_dtype=True,
450 | 544 |         make_refittable=False,
    | 545 | +        torch_export_dynamic_shapes=None,
451 | 546 |     ):
    | 547 | +        # TODO: lan to remove this and set use_dynamo_tracer to True by default
    | 548 | +        # once all the converter test files are moved to use_dynamo_tracer
    | 549 | +        use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer)
452 | 550 |
453 | 551 |         # Previous instance of the interpreter auto-casted 64-bit inputs
454 | 552 |         # We replicate this behavior here
455 | 553 |         compilation_settings = CompilationSettings(
456 | 554 |             truncate_double=True, make_refittable=make_refittable
457 | 555 |         )
458 |     | -
459 | 556 |         mod = self.generate_graph(
460 | 557 |             mod,
461 | 558 |             input_specs,
462 | 559 |             use_dynamo_tracer=use_dynamo_tracer,
463 | 560 |             enable_passes=enable_passes,
464 | 561 |             propagate_shapes=propagate_shapes,
465 | 562 |             settings=compilation_settings,
    | 563 | +            torch_export_dynamic_shapes=torch_export_dynamic_shapes,
466 | 564 |         )
467 | 565 |
468 | 566 |         if check_dtype:
469 |     | -            output_dtypes = infer_module_output_dtypes_for_test(
470 |     | -                mod,
471 |     | -                input_specs,
472 |     | -                compilation_settings.device,
473 |     | -                truncate_double=compilation_settings.truncate_double,
474 |     | -            )
    | 567 | +            if use_dynamo_tracer:
    | 568 | +                output_dtypes = infer_module_output_dtypes(
    | 569 | +                    mod,
    | 570 | +                    truncate_double=compilation_settings.truncate_double,
    | 571 | +                )
    | 572 | +            else:
    | 573 | +                output_dtypes = infer_module_output_dtypes_for_test(
    | 574 | +                    mod,
    | 575 | +                    input_specs,
    | 576 | +                    compilation_settings.device,
    | 577 | +                    truncate_double=compilation_settings.truncate_double,
    | 578 | +                )
475 | 579 |
476 | 580 |         interp = TRTInterpreter(
477 | 581 |             mod,
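A hypothetical call site (test, module, and variable names are illustrative) showing how a converter test can now thread an explicit dynamic-shapes spec through the new kwarg to torch.export:

    dim = torch.export.Dim("batch", min=2, max=8)
    self.run_test_with_dynamic_shape(
        ToyModule(),
        input_specs,
        use_dynamo_tracer=True,
        torch_export_dynamic_shapes={"x": {0: dim}},
    )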