Skip to content

Commit 766c270

Browse files
authored
fix: FakeTensors appearing in get_attr calls (#2669)
1 parent 76c9ebd commit 766c270

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
TRTInterpreterResult,
1313
)
1414
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
15-
from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
15+
from torch_tensorrt.dynamo.utils import get_torch_inputs
1616

1717

1818
def interpret_module_to_result(
@@ -29,7 +29,6 @@ def interpret_module_to_result(
2929
TRTInterpreterResult
3030
"""
3131
torch_inputs = get_torch_inputs(inputs, settings.device)
32-
module.to(to_torch_device(settings.device))
3332
module_outputs = module(*torch_inputs)
3433

3534
if not isinstance(module_outputs, (list, tuple)):

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111

1212
# Modify import location of utilities based on Torch version
1313
if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
14-
from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
14+
from torch._inductor.freezing import ConstantFolder
1515
else:
16-
from torch._inductor.constant_folding import (
17-
ConstantFolder,
18-
replace_node_with_constant,
19-
)
16+
from torch._inductor.constant_folding import ConstantFolder
2017

2118
logger = logging.getLogger(__name__)
2219

@@ -36,7 +33,9 @@ def constant_fold(
3633
cf.run()
3734

3835
for node, constant in cf.node_replacements.items():
39-
replace_node_with_constant(gm, node, constant)
36+
replace_node_with_constant(
37+
gm, node, torch.nn.Parameter(constant.cuda(), requires_grad=False)
38+
)
4039

4140
erased_params = []
4241
for node in gm.graph.nodes:
@@ -55,6 +54,40 @@ def constant_fold(
5554
return gm
5655

5756

57+
def replace_node_with_constant(
58+
gm: torch.fx.GraphModule, node: torch.fx.Node, constant: torch.Tensor
59+
) -> None:
60+
"""Adapted from:
61+
https://github.com/pytorch/pytorch/blob/bcf35c6ae62bb6560befa3550e37a8283944e5f4/torch/_inductor/constant_folding.py#L17-L43
62+
63+
Modified to register parameters, instead of buffers for frozen constants
64+
"""
65+
g = gm.graph
66+
67+
if not hasattr(gm, "_frozen_param_count"):
68+
gm._frozen_param_count = 0
69+
70+
i = gm._frozen_param_count
71+
72+
while True:
73+
qualname = f"_frozen_param{i}"
74+
if not hasattr(gm, qualname):
75+
break
76+
i += 1
77+
78+
gm._frozen_param_count = i + 1
79+
80+
with g.inserting_before(node):
81+
new_input_node = g.create_node("get_attr", qualname, (), {})
82+
node.replace_all_uses_with(new_input_node)
83+
new_input_node.meta.update(node.meta)
84+
g.erase_node(node)
85+
86+
# Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
87+
gm.register_parameter(qualname, constant)
88+
setattr(gm, qualname, constant)
89+
90+
5891
# TODO: Delete this class when the following code is fixed in nightly:
5992
# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
6093
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]

0 commit comments

Comments
 (0)