fix: FakeTensors appearing in get_attr calls #2669

Merged · 2 commits · Mar 20, 2024

3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -12,7 +12,7 @@
     TRTInterpreterResult,
 )
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
+from torch_tensorrt.dynamo.utils import get_torch_inputs
 
 
 def interpret_module_to_result(
@@ -29,7 +29,6 @@ def interpret_module_to_result(
         TRTInterpreterResult
     """
     torch_inputs = get_torch_inputs(inputs, settings.device)
-    module.to(to_torch_device(settings.device))
     module_outputs = module(*torch_inputs)
 
     if not isinstance(module_outputs, (list, tuple)):
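Below the hunk, a minimal sketch (not repository code) of the device-handling split it leaves behind: example inputs are placed on the target device, while the module itself is no longer moved with `module.to(...)`. The helper name is hypothetical, a stand-in for `torch_tensorrt.dynamo.utils.get_torch_inputs`:

```python
import torch


def materialize_inputs(example_tensors: list, device: str) -> list:
    # Hypothetical stand-in for torch_tensorrt.dynamo.utils.get_torch_inputs:
    # place each example input on the target device.
    torch_device = torch.device(device)
    return [t.to(torch_device) for t in example_tensors]


# Only the inputs move; the module stays put, since its frozen constants are
# already device-placed during constant folding (see the next file's change).
torch_inputs = materialize_inputs([torch.randn(2, 3)], "cpu")
```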
43 changes: 37 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -11,12 +11,9 @@

 # Modify import location of utilities based on Torch version
 if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
-    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
+    from torch._inductor.freezing import ConstantFolder
 else:
-    from torch._inductor.constant_folding import (
-        ConstantFolder,
-        replace_node_with_constant,
-    )
+    from torch._inductor.constant_folding import ConstantFolder
 
 logger = logging.getLogger(__name__)

@@ -36,7 +33,7 @@ def constant_fold(
     cf.run()
 
     for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
+        replace_node_with_constant(gm, node, torch.nn.Parameter(constant.cuda()))
 
     erased_params = []
     for node in gm.graph.nodes:
@@ -60,6 +57,40 @@ def constant_fold(
     return gm
 
 
+def replace_node_with_constant(
+    gm: torch.fx.GraphModule, node: torch.fx.Node, constant: torch.Tensor
+) -> None:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/bcf35c6ae62bb6560befa3550e37a8283944e5f4/torch/_inductor/constant_folding.py#L17-L43
+
+    Modified to register parameters, instead of buffers, for frozen constants
+    """
+    g = gm.graph
+
+    if not hasattr(gm, "_frozen_param_count"):
+        gm._frozen_param_count = 0
+
+    i = gm._frozen_param_count
+
+    while True:
+        qualname = f"_frozen_param{i}"
+        if not hasattr(gm, qualname):
+            break
+        i += 1
+
+    gm._frozen_param_count = i + 1
+
+    with g.inserting_before(node):
+        new_input_node = g.create_node("get_attr", qualname, (), {})
+        node.replace_all_uses_with(new_input_node)
+        new_input_node.meta.update(node.meta)
+        g.erase_node(node)
+
+    # Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
+    gm.register_parameter(qualname, constant)
+    setattr(gm, qualname, constant)
+
+
 # TODO: Delete this class when the following code is fixed in nightly:
 # https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
 class _TorchTensorRTConstantFolder(ConstantFolder):  # type: ignore[misc]
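To make the registered-parameter behavior concrete, a hedged usage sketch (a toy example, not part of the PR): it folds one constant-valued node of a traced module into a frozen `nn.Parameter` via `replace_node_with_constant`, leaving a `get_attr` access behind. A CPU tensor stands in for the `constant.cuda()` call in `constant_fold` so it runs without a GPU, and the import path assumes this PR's patched module:

```python
import torch
import torch.fx as fx

# Assumes the patched module from this PR is importable
from torch_tensorrt.dynamo.lowering.passes.constant_folding import (
    replace_node_with_constant,
)


class ToyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `torch.ones(3) * 2` is constant-valued, hence foldable
        return x + torch.ones(3) * 2


gm = fx.symbolic_trace(ToyModule())

# Precompute the constant node's value, standing in for an entry of
# ConstantFolder's node_replacements; a CPU Parameter replaces constant.cuda()
mul_node = next(n for n in gm.graph.nodes if n.name == "mul")
folded = torch.nn.Parameter(torch.ones(3) * 2, requires_grad=False)

replace_node_with_constant(gm, mul_node, folded)
gm.recompile()

assert any(n.op == "get_attr" for n in gm.graph.nodes)
assert "_frozen_param0" in dict(gm.named_parameters())
print(gm(torch.zeros(3)))  # tensor([2., 2., 2.])
```

A full `constant_fold` pass would also clean up the now-dead `torch.ones` node; the sketch leaves that step out.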