feat: Add sample torch.compile backend for tensorrt aten path #1751
Conversation
Proposed UX

```python
from torch_tensorrt.fx.tracer.dispatch_tracer.tensorrt_dynamo_backend import fx2trt, tensorrt_backend
import torch._dynamo as torchdynamo
import torch

torchdynamo.optimize(fx2trt, nopython=True)(model)

##### OR #####

torch.compile(model, backend=tensorrt_backend)
```

Both of the above can handle dynamic data-dependent control flow.
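A minimal end-to-end sketch of the proposed `torch.compile` path follows; the toy model and input shapes are illustrative, not from the PR:

```python
import torch
# Assumed import path, taken from the proposed UX above
from torch_tensorrt.fx.tracer.dispatch_tracer.tensorrt_dynamo_backend import tensorrt_backend

# Toy model, purely illustrative
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
).eval().cuda()

# Compile with the proposed TensorRT backend
opt_model = torch.compile(model, backend=tensorrt_backend)

# The first call triggers tracing and engine build; later calls reuse the result
x = torch.randn(8, 64, device="cuda")
out = opt_model(x)
```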
Notes on FakeTensor Issues During Compilation

During compilation, we sometimes encounter errors like the following.

Calling `to_numpy()` on a FakeTensor:

```
File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 328, in call_function
  return converter(self.network, target, args, kwargs, self._cur_node_name)
File "~/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py", line 313, in aten_ops_sub
  return add_sub(network, target, kwargs_new, name)
File "~/TensorRT/py/torch_tensorrt/fx/converters/operator.py", line 876, in add_sub
  return add_binary_elementwise_layer(
File "~/TensorRT/py/torch_tensorrt/fx/converters/operator.py", line 141, in add_binary_elementwise_layer
  lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
File "~/TensorRT/py/torch_tensorrt/fx/converters/converter_utils.py", line 237, in get_trt_tensor
  return create_constant(network, input_val, name, dtype)
File "~/TensorRT/py/torch_tensorrt/fx/converters/converter_utils.py", line 201, in create_constant
  constant = network.add_constant(value.shape, to_numpy(value))
File "~/TensorRT/py/torch_tensorrt/fx/converters/converter_utils.py", line 497, in to_numpy
  return tensor.cpu().detach().contiguous().numpy()
RuntimeError: .numpy() is not supported for tensor subclasses.

While executing %sub : [#users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1.0, %convert_element_type), kwargs = {_itensor_to_tensor_meta: {<tensorrt.tensorrt.ITensor object at 0x7f7fd13ce8f0>: None}})
```

Providing FakeTensors as input to Graph Modules:

```
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 266, in tree_map_only
  return tree_map(map_only(ty)(fn), pytree)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 196, in tree_map
  return tree_unflatten([fn(i) for i in flat_args], spec)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 196, in <listcomp>
  return tree_unflatten([fn(i) for i in flat_args], spec)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 247, in inner
  return f(x)
File "/usr/local/lib/python3.8/dist-packages/torch/_subclasses/fake_tensor.py", line 1282, in validate
  raise Exception(
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[-0.0102, -0.0615, -0.0265, ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323, ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326, ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135, ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019, ..., 0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160, ..., -0.0081, -0.0475, 0.0753]],
       device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., device='meta', size=(1, 14), dtype=torch.int32), cuda:0), 0), **{})
```

Solutions

There are two key functions/utilities which can help in these cases:

1. The `@fake_tensor_unsupported` decorator:

```python
@fake_tensor_unsupported
def fx2trt_compiler(gm, example_inputs):
    ...
```

2. The `FakeTensorMode` context manager:

```python
with FakeTensorMode(allow_non_fake_inputs=True):
    ...
```

It remains to be seen whether these will be able to resolve the issues described above, or if they address the root cause of those issues. Additionally, we should consider whether it is reasonable to call …
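As a concrete sketch of option 1 applied at backend registration: the import location of `fake_tensor_unsupported` is an assumption based on torch 2.0-era Dynamo (it has moved between versions), and the backend body is a placeholder rather than the PR's implementation:

```python
import torch
import torch._dynamo as td
# Assumed import location; this has moved between torch versions
from torch._dynamo.backends.common import fake_tensor_unsupported


@td.register_backend
@fake_tensor_unsupported  # materializes real tensors before this function is called
def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Placeholder body: lowering of `gm` to TensorRT would happen here.
    # Returning gm.forward simply runs the captured graph eagerly.
    return gm.forward
```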
- Add backend adapted from previous `fx2trt_compiler` provided by Dynamo
- Currently, the TRTSplitter needs work to fully support the `aten` path
- Additionally, the existing `aten` pass was reworked to exclude the `torch._dynamo.export` call, which may be necessary here
(force-pushed from 36c95f8 to 6a8102c)
```python
@td.register_backend
@fake_tensor_unsupported
```
@frank-wei Wondering what your thoughts are on using `@fake_tensor_unsupported` here to avoid the two errors mentioned in #1751 (comment)? Could there be adverse effects on symbolic tracing/Dynamo?
My understanding of FakeTensor is that it is an empty tensor used only for shape and dtype inference. It is supposed to be used in the aten tracer to help with shape inference. Why do we need to use it directly?
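For illustration, a minimal sketch of that behavior, assuming the torch 2.x `FakeTensorMode` API: FakeTensors carry shape/dtype metadata without real storage, which is why data-dependent calls such as the `.numpy()` in the first traceback above fail.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Factory functions produce FakeTensors inside the mode
    t = torch.empty(2, 3, dtype=torch.float16)
    print(t.shape, t.dtype)  # metadata is available: torch.Size([2, 3]) torch.float16
    # t.numpy()  # would raise: .numpy() is not supported for tensor subclasses
```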
- Update implementation to use Dynamo partition functionality
- Update implementation to use Dynamo decompositions to replace inplace operators
- Name backends using standard names
- Add documentation, print statements, and helper functions to the code
@frank-wei The reason I think using …
- Improve overall documentation and commenting, improve code delineation and separation of functionality
- Add dedicated settings and defaults files to centralize data and improve code readability, as well as reduce duplication of code
- Improve documentation of functions, types, and comments
- Rework logic to make compiler more uniform with existing torch tensorrt compilers, while retaining key Dynamo keywords needed for compilation via the torch.compile path
Centralized defaults in `_defaults.py` and made a frozen dataclass to store key compilation settings, for easier readability and adaptability as new features are added to both Dynamo and Torch-TensorRT.
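A minimal sketch of that pattern, with illustrative field names and defaults rather than the PR's actual settings; a frozen dataclass keeps the settings immutable once constructed:

```python
from dataclasses import dataclass

# Illustrative defaults; the PR centralizes values like these in _defaults.py
DEBUG = False
PRECISION = "fp32"
WORKSPACE_SIZE = 20 << 30  # bytes
MIN_BLOCK_SIZE = 3


@dataclass(frozen=True)
class CompilationSettings:
    """Key compilation settings; frozen so they cannot be mutated after creation."""
    debug: bool = DEBUG
    precision: str = PRECISION
    workspace_size: int = WORKSPACE_SIZE
    min_block_size: int = MIN_BLOCK_SIZE


settings = CompilationSettings(debug=True)
# settings.debug = False  # would raise dataclasses.FrozenInstanceError
```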
- Improve overall functionality, fix bugs
- Move functions into __init__.py
- Improve overall documentation, comments, function header typing, and code organization
- Add support for Input objects, add utilities (an illustrative sketch follows below)
- Add modeling e2e test cases for Dynamo backend
- Improve defaults and settings in Dynamo class
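For context, `torch_tensorrt.Input` objects describe (possibly dynamic) input shape ranges in the existing Torch-TensorRT API; how the Dynamo backend consumes them in this PR is not shown here, so this is only an illustration:

```python
import torch
import torch_tensorrt

# Describe a dynamic-batch input using the existing Input API;
# the shape ranges here are purely illustrative
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(16, 3, 224, 224),
        dtype=torch.float,
    )
]
```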
(force-pushed from fefd1f9 to ef608e5)
(force-pushed from ef608e5 to 226cc79)
LGTM
Description

- Add backend adapted from previous `fx2trt_compiler` provided by Dynamo
- Currently, the `TRTSplitter` needs work to fully support the `aten` path
- Additionally, the existing `aten` pass was reworked to exclude the `torch._dynamo.export` call, which may be necessary here

Note: The implementation is just a sample and is for use in testing - it has very few customization options (no precision specifications, etc.). Additionally, some of its components require modification for correct usage overall. Specifically, the `TRTSplitter` may need modifications to work with the `aten` path effectively.

Note: The implementation was adapted from the earlier fx2trt Dynamo backend.
Type of change
Checklist: