Skip to content

Commit 88f7895

Browse files
committed
fix: FakeTensors appearing in get_attr calls
- Register all constants as model parameters, which do not get fake-ified by the active FakeTensor context - Buffers and other constant registrations can be fake-ified, which is problematic for TRT tracing
1 parent fb07513 commit 88f7895

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
TRTInterpreterResult,
1313
)
1414
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
15-
from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
15+
from torch_tensorrt.dynamo.utils import get_torch_inputs
1616

1717

1818
def interpret_module_to_result(
@@ -29,7 +29,6 @@ def interpret_module_to_result(
2929
TRTInterpreterResult
3030
"""
3131
torch_inputs = get_torch_inputs(inputs, settings.device)
32-
module.to(to_torch_device(settings.device))
3332
module_outputs = module(*torch_inputs)
3433

3534
if not isinstance(module_outputs, (list, tuple)):

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111

1212
# Modify import location of utilities based on Torch version
1313
if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
14-
from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
14+
from torch._inductor.freezing import ConstantFolder
1515
else:
16-
from torch._inductor.constant_folding import (
17-
ConstantFolder,
18-
replace_node_with_constant,
19-
)
16+
from torch._inductor.constant_folding import ConstantFolder
2017

2118
logger = logging.getLogger(__name__)
2219

@@ -36,7 +33,7 @@ def constant_fold(
3633
cf.run()
3734

3835
for node, constant in cf.node_replacements.items():
39-
replace_node_with_constant(gm, node, constant)
36+
replace_node_with_constant(gm, node, torch.nn.Parameter(constant.cuda()))
4037

4138
erased_params = []
4239
for node in gm.graph.nodes:
@@ -60,6 +57,40 @@ def constant_fold(
6057
return gm
6158

6259

60+
def replace_node_with_constant(
61+
gm: torch.fx.GraphModule, node: torch.fx.Node, constant: torch.Tensor
62+
) -> None:
63+
"""Adapted from:
64+
https://github.com/pytorch/pytorch/blob/bcf35c6ae62bb6560befa3550e37a8283944e5f4/torch/_inductor/constant_folding.py#L17-L43
65+
66+
Modified to register parameters, instead of buffers for frozen constants
67+
"""
68+
g = gm.graph
69+
70+
if not hasattr(gm, "_frozen_param_count"):
71+
gm._frozen_param_count = 0
72+
73+
i = gm._frozen_param_count
74+
75+
while True:
76+
qualname = f"_frozen_param{i}"
77+
if not hasattr(gm, qualname):
78+
break
79+
i += 1
80+
81+
gm._frozen_param_count = i + 1
82+
83+
with g.inserting_before(node):
84+
new_input_node = g.create_node("get_attr", qualname, (), {})
85+
node.replace_all_uses_with(new_input_node)
86+
new_input_node.meta.update(node.meta)
87+
g.erase_node(node)
88+
89+
# Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
90+
gm.register_parameter(qualname, constant)
91+
setattr(gm, qualname, constant)
92+
93+
6394
# TODO: Delete this class when the following code is fixed in nightly:
6495
# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
6596
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]

0 commit comments

Comments
 (0)