Skip to content

Commit a646e59

Browse files
authored
fix: Repair issue in Torch Constant Folder (#2375)
1 parent 83176fe commit a646e59

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Sequence
2+
from typing import Any, Sequence
33

44
import torch
55
from torch_tensorrt._utils import sanitized_torch_version
@@ -32,7 +32,7 @@ def constant_fold(
3232
3333
Modifies the graph in-place and replaces node with constants
3434
"""
35-
cf = ConstantFolder(gm, skip_constructors=False)
35+
cf = _TorchTensorRTConstantFolder(gm, skip_constructors=False)
3636
cf.run()
3737

3838
for node, constant in cf.node_replacements.items():
@@ -58,3 +58,14 @@ def constant_fold(
5858
logger.debug(f"Graph after constant folding:\n{gm.graph}")
5959

6060
return gm
61+
62+
63+
# TODO: Delete this class when the following code is fixed in nightly:
64+
# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
65+
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]
66+
def __init__(self, *args: Any, **kwargs: Any) -> None:
67+
super().__init__(*args, **kwargs)
68+
69+
# TODO: Update this function when quantization is added
70+
def is_impure(self, node: torch.fx.node.Node) -> bool:
71+
return False

0 commit comments

Comments
 (0)