
feat: Add sample torch.compile backend for tensorrt aten path #1751


Merged
merged 14 commits into pytorch:dynamo_changes from sample_backend on Apr 13, 2023

Conversation

gs-olive
Collaborator

@gs-olive gs-olive commented Mar 20, 2023

Description

  • Add backend adapted from previous fx2trt_compiler provided by Dynamo
  • Currently, the TRTSplitter needs work to fully support the aten path
  • Additionally, the existing aten pass was reworked to exclude the torch._dynamo.export call, which may be necessary here

Note: The implementation is just a sample intended for testing; it has very few customization options (no precision specification, etc.). Additionally, some of its components require modification for correct overall usage. Specifically, the TRTSplitter may need changes to work effectively with the aten path.

Note: The implementation was adapted from the earlier fx2trt Dynamo backend

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • [x] My code follows the style guidelines of this project (You can use the linters)
  • [x] I have performed a self-review of my own code
  • [x] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ ] I have made corresponding changes to the documentation
  • [ ] I have added tests to verify my fix or my feature
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have added the relevant labels to my PR so that relevant reviewers are notified

@gs-olive gs-olive added the WIP (Work is in progress, pull request should not be merged yet) label Mar 20, 2023
@github-actions github-actions bot requested a review from narendasan March 20, 2023 22:01
@gs-olive gs-olive self-assigned this Mar 20, 2023
@gs-olive
Collaborator Author

gs-olive commented Mar 21, 2023

Proposed UX

from torch_tensorrt.fx.tracer.dispatch_tracer.tensorrt_dynamo_backend import fx2trt, tensorrt_backend
import torch._dynamo as torchdynamo
import torch

torchdynamo.optimize(fx2trt, nopython=True)(model)

##### OR #####

torch.compile(model, backend=tensorrt_backend)

Both of the above can handle dynamic, data-dependent control flow.
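
For illustration, here is a minimal sketch (not from this PR) of a module with data-dependent control flow compiled through the second path; the backend import follows the proposed UX above, and ConditionalModel is a hypothetical example module.

import torch
from torch_tensorrt.fx.tracer.dispatch_tracer.tensorrt_dynamo_backend import tensorrt_backend

class ConditionalModel(torch.nn.Module):
    def forward(self, x):
        # The branch taken depends on runtime tensor values, not just shapes,
        # so the compiler must handle the graph break at this conditional
        if x.sum() > 0:
            return torch.relu(x)
        return torch.sigmoid(x)

model = ConditionalModel().eval().cuda()
compiled = torch.compile(model, backend=tensorrt_backend)
out = compiled(torch.randn(4, 4, device="cuda"))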

@gs-olive
Collaborator Author

Notes on FakeTensor Issues During Compilation

During compilation, we sometimes encounter errors like this:

Calling to_numpy() on a FakeTensor
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 328, in call_function
    return converter(self.network, target, args, kwargs, self._cur_node_name)
  File "~/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py", line 313, in aten_ops_sub
    return add_sub(network, target, kwargs_new, name)
  File "~/TensorRT/py/torch_tensorrt/fx/converters/operator.py", line 876, in add_sub
    return add_binary_elementwise_layer(
  File "~/TensorRT/py/torch_tensorrt/fx/converters/operator.py", line 141, in add_binary_elementwise_layer
    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
  File "~/TensorRT/py/torch_tensorrt/fx/converters/converter_utils.py", line 237, in get_trt_tensor
    return create_constant(network, input_val, name, dtype)
  File "~/TensorRT/py/torch_tensorrt/fx/converters/converter_utils.py", line 201, in create_constant
    constant = network.add_constant(value.shape, to_numpy(value))
  File "~/TensorRT/py/torch_tensorrt/fx/converters/converter_utils.py", line 497, in to_numpy
    return tensor.cpu().detach().contiguous().numpy()
RuntimeError: .numpy() is not supported for tensor subclasses.

While executing %sub : [#users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1.0, %convert_element_type), kwargs = {_itensor_to_tensor_meta: {<tensorrt.tensorrt.ITensor object at 0x7f7fd13ce8f0>: None}})

Providing FakeTensors as input to Graph Modules
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 266, in tree_map_only
    return tree_map(map_only(ty)(fn), pytree)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 196, in tree_map
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 196, in <listcomp>
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 247, in inner
    return f(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/_subclasses/fake_tensor.py", line 1282, in validate
    raise Exception(
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]],
       device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., device='meta', size=(1, 14), dtype=torch.int32), cuda:0), 0), **{}) 

Solutions

There are two key functions/utilities which can help in these cases:
1. from torch._dynamo.backends.common import fake_tensor_unsupported:

@fake_tensor_unsupported
def fx2trt_compiler(gm, example_inputs):
    ...

2. from torch._subclasses import FakeTensorMode

with FakeTensorMode(allow_non_fake_inputs=True):
    ...

It remains to be seen whether these will resolve the issues described above, and whether they address the root cause of those issues. Additionally, we should consider whether it is reasonable to call torch._dynamo.export from within a call to torch.compile or torch._dynamo.optimize.
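
As a rough sketch of how the two utilities above could be applied (the backend names here are hypothetical, not the PR's implementation):

import torch
from torch._dynamo.backends.common import fake_tensor_unsupported
from torch._subclasses import FakeTensorMode

# Option 1: the decorator makes Dynamo hand the backend real tensors rather than
# FakeTensors, so converters that call .numpy() or build constants do not fail
@fake_tensor_unsupported
def sample_trt_backend(gm: torch.fx.GraphModule, example_inputs):
    # ... conversion / splitting would happen here ...
    return gm.forward

# Option 2: permit real tensors (e.g. module parameters) inside a fake-tensor context
def trace_with_fake_mode(gm: torch.fx.GraphModule, example_inputs):
    with FakeTensorMode(allow_non_fake_inputs=True):
        # Mixing real parameters with FakeTensor inputs no longer raises
        # "Please convert all Tensors to FakeTensors first"
        return gm(*example_inputs)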


@td.register_backend
@fake_tensor_unsupported
Collaborator Author

@frank-wei Wondering what your thoughts are on using @fake_tensor_unsupported here to avoid the two errors mentioned in #1751 (comment)? Could there be adverse effects on symbolic tracing/Dynamo?

@frank-wei
Contributor

My understanding of FakeTensor is that it is an empty tensor used only for shape and dtype inference. It is supposed to be used in the aten tracer to help with shape inference. Why do we need to use it directly?
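
For reference, a small illustration (not from this thread) of what a FakeTensor carries: shape, dtype, and device metadata, but no real storage.

import torch
from torch._subclasses import FakeTensorMode

with FakeTensorMode():
    t = torch.empty(3, 5, dtype=torch.float16)

# Metadata is tracked even though no memory is allocated for the values
print(type(t).__name__, t.shape, t.dtype)  # FakeTensor torch.Size([3, 5]) torch.float16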

@gs-olive gs-olive changed the base branch from main to dynamo_aten_backend March 31, 2023 16:22
gs-olive added a commit that referenced this pull request Mar 31, 2023
gs-olive added 2 commits April 5, 2023 22:49
- Update implementation to use Dynamo partition functionality
- Update implementation to use Dynamo decompositions to replace inplace
operators
- Name backends using standard names
- Add documentation, print statements, and helper functions to the code
@gs-olive
Collaborator Author

gs-olive commented Apr 6, 2023

@frank-wei The reason I think using @fake_tensor_unsupported is necessary here is that, without it, the frequent calls to the create_constant function in converters/converter_utils.py will fail: they attempt to instantiate a constant tensor, but FakeTensorMode disallows creation of any non-fake tensors within its context. This leads to the error shown in the first dropdown here: #1751 (comment).
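
A hypothetical repro of the failure mode described here (not the PR's code): a tensor created under FakeTensorMode becomes a FakeTensor, and the .numpy() call inside to_numpy() then raises the error shown in the first dropdown above.

import torch
from torch._subclasses import FakeTensorMode

with FakeTensorMode():
    # create_constant-style factory call: inside the mode this yields a FakeTensor
    value = torch.ones(2, 2)

try:
    # Mirrors the tensor.cpu().detach().contiguous().numpy() chain in to_numpy()
    value.numpy()
except RuntimeError as err:
    print(err)  # ".numpy() is not supported for tensor subclasses."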

gs-olive added 3 commits April 7, 2023 10:59
- Improve overall documentation and commenting, improve code delineation
and separation of functionality
- Add dedicated settings and defaults files to centralize data and
improve code readability, as well as reduce duplication of code
- Improve documentation of functions, types, and comments
- Rework logic to make compiler more uniform with existing torch
tensorrt compilers, while retaining key Dynamo keywords needed for
compilation via the torch.compile path
Collaborator Author

Centralized defaults in _defaults.py and made a frozen dataclass to store key compilation settings, for easier readability and adaptability as new features are added to both Dynamo and Torch-TensorRT.
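
Roughly, that pattern could look like the following sketch; the field names and default values here are illustrative, not the PR's actual settings.

import torch
from dataclasses import dataclass

# In the PR, values like these live in a dedicated _defaults.py; inlined here for brevity
DEBUG = False
MAX_WORKSPACE_SIZE = 20 << 30
PRECISION = torch.float32

@dataclass(frozen=True)
class CompilationSettings:
    """Immutable bundle of key compilation options passed through the backend."""
    debug: bool = DEBUG
    workspace_size: int = MAX_WORKSPACE_SIZE
    precision: torch.dtype = PRECISION

settings = CompilationSettings(debug=True)
# settings.debug = False  # would raise FrozenInstanceError; settings stay immutable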

- Improve overall functionality, fix bugs
- Move functions into __init__.py
- Improve overall documentation, comments, function header typing, and
code organization
- Add support for Input objects, add utilities
- Add modeling e2e test cases for Dynamo backend
- Improve defaults and settings in Dynamo class
@gs-olive gs-olive force-pushed the sample_backend branch 4 times, most recently from fefd1f9 to ef608e5 on April 12, 2023 02:54
@gs-olive gs-olive changed the base branch from dynamo_aten_backend to dynamo_changes April 12, 2023 20:18
@gs-olive gs-olive requested a review from peri044 April 12, 2023 21:17
@gs-olive gs-olive marked this pull request as ready for review April 12, 2023 21:18
@gs-olive gs-olive removed the WIP (Work is in progress, pull request should not be merged yet) label Apr 12, 2023
Collaborator

@peri044 peri044 left a comment

LGTM

@peri044 peri044 merged commit 33255de into pytorch:dynamo_changes Apr 13, 2023