Skip to content

Commit 51266db

Browse files
committed
feat: Add preliminary support for freezing tensors in Dynamo
fix: Refactor tensor freezing in Dynamo; key op fixes for failing tests
1 parent fe0d8e0 commit 51266db

File tree

8 files changed

+92
-52
lines changed

8 files changed

+92
-52
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

33
import logging
4-
from functools import partial
4+
import unittest
55
from typing import Any, Callable, Sequence
66

77
import torch
88
import torch._dynamo as td
9-
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
9+
from torch._dynamo.utils import detect_fake_mode
10+
from torch._functorch.aot_autograd import aot_export_joint_simple
1011
from torch_tensorrt.dynamo import CompilationSettings
1112
from torch_tensorrt.dynamo.compile import compile_module
1213
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -33,8 +34,7 @@ def torch_tensorrt_backend(
3334

3435
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3536

36-
compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
37-
return compiled_mod
37+
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
3838

3939

4040
@td.register_backend(name="aot_torch_tensorrt_aten") # type: ignore[misc]
@@ -43,21 +43,26 @@ def aot_torch_tensorrt_aten_backend(
4343
) -> torch.nn.Module:
4444
settings = parse_dynamo_kwargs(kwargs)
4545

46-
custom_backend = partial(
47-
_pretraced_backend,
48-
settings=settings,
49-
)
50-
5146
# Perform Pre-AOT Lowering for Module-Level Replacement
5247
gm = pre_aot_substitutions(gm)
5348

54-
# Invoke AOTAutograd to translate operators to aten
55-
return aot_module_simplified(
56-
gm,
57-
sample_inputs,
58-
fw_compiler=make_boxed_compiler(custom_backend),
59-
decompositions=get_decompositions(settings.enable_experimental_decompositions),
60-
)
49+
fake_mode = detect_fake_mode(sample_inputs)
50+
51+
# Place backend tracing within FakeTensor context allowing nonfake Tensors
52+
with unittest.mock.patch.object(
53+
fake_mode, "allow_non_fake_inputs", True
54+
), fake_mode:
55+
# Invoke AOTAutograd to translate operators to aten
56+
graph_module = aot_export_joint_simple(
57+
gm,
58+
sample_inputs,
59+
trace_joint=False,
60+
decompositions=get_decompositions(
61+
settings.enable_experimental_decompositions
62+
),
63+
)
64+
65+
return _pretraced_backend(graph_module, sample_inputs, settings)
6166

6267

6368
def _pretraced_backend(

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from datetime import datetime
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

6-
import numpy
6+
import numpy as np
77

88
# @manual=//deeplearning/trt/python:py_tensorrt
99
import tensorrt as trt
1010
import torch
1111
import torch.fx
1212
from torch.fx.node import _get_qualified_name
1313
from torch.fx.passes.shape_prop import TensorMetadata
14+
from torch.utils._python_dispatch import _disable_current_modes
1415
from torch_tensorrt._Input import Input
1516
from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
1617
from torch_tensorrt.fx.observer import Observer
@@ -169,7 +170,7 @@ def run(
169170

170171
cache = None
171172
if timing_cache:
172-
cache_file = numpy.array(timing_cache)
173+
cache_file = np.array(timing_cache)
173174
cache = builder_config.create_timing_cache(cache_file.tobytes())
174175
else:
175176
cache = builder_config.create_timing_cache(b"")
@@ -323,6 +324,21 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
323324
assert self._cur_node_name is not None
324325
return converter(self.network, target, args, kwargs, self._cur_node_name)
325326

327+
def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
328+
with _disable_current_modes():
329+
from torch_tensorrt.fx.converters import to_numpy
330+
331+
frozen_attr = self.fetch_attr(target)
332+
333+
if isinstance(frozen_attr, torch.nn.Parameter):
334+
constant_tensor = frozen_attr.data
335+
else:
336+
constant_tensor = frozen_attr
337+
338+
network_constant = to_numpy(constant_tensor)
339+
340+
return network_constant
341+
326342
def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
327343
assert isinstance(target, str)
328344
converter = CONVERTERS.get(self._cur_node)
@@ -344,6 +360,17 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
344360
else:
345361
outputs = (args[0],)
346362

363+
for output_idx in range(len(outputs)):
364+
from torch_tensorrt.fx.converters import get_trt_tensor
365+
366+
output = outputs[output_idx]
367+
368+
if not isinstance(output, trt.tensorrt.ITensor):
369+
new_output = get_trt_tensor(self.network, output, target)
370+
outputs = (
371+
outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
372+
)
373+
347374
if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
348375
raise RuntimeError("TensorRT requires all outputs to be Tensor!")
349376

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, List, Optional, Sequence, Union, cast
33

44
import numpy as np
5+
import tensorrt as trt
56
import torch
67
from torch.fx.node import Target
78
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -19,8 +20,6 @@
1920
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
2021
from torch_tensorrt.fx.utils import get_dynamic_dims
2122

22-
import tensorrt as trt
23-
2423
_LOGGER: logging.Logger = logging.getLogger(__name__)
2524

2625

@@ -101,9 +100,15 @@ def layer_norm(
101100
"of the TensorRT region!"
102101
)
103102

104-
gamma = weight.detach().cpu().float().numpy()
103+
gamma = (
104+
weight.detach().cpu().float().numpy()
105+
if isinstance(weight, torch.Tensor)
106+
else weight
107+
)
105108
gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
106-
beta = bias.detach().cpu().float().numpy()
109+
beta = (
110+
bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
111+
)
107112
beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
108113
eps_field = trt.PluginField(
109114
"eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._decompositions import get_decompositions # noqa: F401
2+
from ._freeze_aot_graph import * # noqa: F401
23
from ._fusers import * # noqa: F401
34
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
45
from ._pre_aot_lowering import register_substitution # noqa: F401

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ def is_node_supported(
153153
) -> bool:
154154
node_name = ConverterRegistry.qualified_name_or_str(node.target)
155155

156-
if node in CONVERTERS and node_name not in self.torch_executed_ops:
156+
if (
157+
node.target in CONVERTERS.keys()
158+
or (node.op == "get_attr" and "constant" in node_name)
159+
) and node_name not in self.torch_executed_ops:
157160
# If node is a proper, supported computational node, store the operator
158161
if not node.is_impure():
159162
if node_name not in self.supported_operators:

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,27 @@
33
import math
44
import operator
55
import warnings
6-
from typing import cast, Dict, Optional, Sequence, Tuple, Union
6+
from typing import Dict, Optional, Sequence, Tuple, Union, cast
77

88
import numpy as np
99

1010
# @manual=//deeplearning/trt/python:py_tensorrt
1111
import tensorrt as trt
1212
import torch
13-
14-
from ..converter_registry import tensorrt_converter
15-
16-
from ..tracer.acc_tracer import acc_ops
17-
from ..types import * # noqa: F403
1813
from torch.fx.immutable_collections import immutable_list
1914
from torch.fx.node import Argument, Target
20-
21-
from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks
22-
23-
from .converter_utils import * # noqa: F403
15+
from torch_tensorrt.fx.converters.impl import activation, convolution
2416
from torch_tensorrt.fx.passes.lower_basic_pass import (
2517
trt_transposed_linear,
2618
trt_transposed_matmul,
2719
)
2820
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
29-
from torch_tensorrt.fx.converters.impl import activation, convolution
21+
22+
from ..converter_registry import tensorrt_converter
23+
from ..tracer.acc_tracer import acc_ops
24+
from ..types import * # noqa: F403
25+
from ..utils import Frameworks, get_dynamic_dims, unified_dtype_converter
26+
from .converter_utils import * # noqa: F403
3027

3128
_LOGGER: logging.Logger = logging.getLogger(__name__)
3229

@@ -2714,8 +2711,14 @@ def acc_ops_linear(
27142711
"dim for linear and it can't be the last dim."
27152712
)
27162713

2717-
if isinstance(kwargs["weight"], torch.Tensor):
2718-
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
2714+
if isinstance(kwargs["weight"], (torch.Tensor, np.ndarray)):
2715+
weight = get_trt_tensor(
2716+
network,
2717+
kwargs["weight"].t()
2718+
if isinstance(kwargs["weight"], torch.Tensor)
2719+
else kwargs["weight"].T,
2720+
f"{name}_weight",
2721+
)
27192722
if target not in (acc_ops.linear, torch.ops.aten.linear):
27202723
weight_op = trt.MatrixOperation.TRANSPOSE
27212724
else:

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import operator
22
import warnings
3+
from enum import Enum, auto
34
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
45

5-
from enum import Enum, auto
66
import numpy as np
77

88
# @manual=//deeplearning/trt/python:py_tensorrt
@@ -20,7 +20,7 @@
2020
TRTPluginFieldCollection,
2121
TRTTensor,
2222
)
23-
from ..utils import unified_dtype_converter, Frameworks
23+
from ..utils import Frameworks, unified_dtype_converter
2424

2525

2626
class SourceIR(Enum):
@@ -271,7 +271,7 @@ def create_constant(
271271
"""
272272
constant = network.add_constant(
273273
(1,) if isinstance(value, (int, float)) else value.shape,
274-
to_numpy(value, dtype),
274+
to_numpy(value, dtype).copy(),
275275
)
276276
constant.name = name
277277
return constant.get_output(0)
@@ -311,7 +311,7 @@ def get_trt_tensor(
311311
elif isinstance(input_val, np.ndarray) and (
312312
input_val.dtype == np.bool_ or input_val.dtype == np.int64
313313
):
314-
input_val = input_val.to(np.int32)
314+
input_val = input_val.astype(np.int32)
315315

316316
if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
317317
return create_constant(network, input_val, name, dtype)

py/torch_tensorrt/fx/converters/impl/convolution.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,23 @@
1-
import numpy as np
21
from typing import Any, Optional, Sequence, Union
32

3+
import numpy as np
4+
45
# @manual=//deeplearning/trt/python:py_tensorrt
56
import tensorrt as trt
67
import torch
78
from torch.fx.node import Target
8-
9+
from torch_tensorrt.fx.converters import acc_ops_converters
910
from torch_tensorrt.fx.converters.converter_utils import (
1011
SourceIR,
1112
extend_attr_to_tuple,
1213
get_dyn_range,
14+
get_trt_tensor,
15+
has_dynamic_shape,
1316
mark_as_int8_layer,
1417
set_layer_name,
15-
has_dynamic_shape,
1618
to_numpy,
17-
get_trt_tensor,
18-
)
19-
from torch_tensorrt.fx.converters import acc_ops_converters
20-
21-
from torch_tensorrt.fx.types import (
22-
TRTNetwork,
23-
TRTTensor,
2419
)
20+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
2521

2622

2723
def convNd(
@@ -54,7 +50,7 @@ def convNd(
5450
)
5551

5652
# Process bias terms
57-
if isinstance(bias, torch.Tensor):
53+
if isinstance(bias, (torch.Tensor, np.ndarray)):
5854
# Transform the bias constant into a Numpy array
5955
bias = to_numpy(bias)
6056

@@ -79,7 +75,7 @@ def convNd(
7975
network, target, tuple(), kwargs, name + "_unsqueeze_weight"
8076
)
8177

82-
elif isinstance(weight, torch.Tensor):
78+
elif isinstance(weight, (torch.Tensor, np.ndarray)):
8379
# Transform the weight constant into a Numpy array
8480
weight = to_numpy(weight)
8581

0 commit comments

Comments
 (0)