Skip to content

Commit b9871c5

Browse files
authored
fix: Wrap import of ConstantFold utilities (#2312)
1 parent 783a760 commit b9871c5

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,27 @@
77
import torch
88
import torch._dynamo as td
99
import torch.utils._pytree as pytree
10+
import torch_tensorrt
1011
from torch._dynamo.utils import detect_fake_mode
1112
from torch._functorch.aot_autograd import _aot_export_function
12-
from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
1313
from torch._ops import OpOverload
1414
from torch_tensorrt.dynamo import CompilationSettings
1515
from torch_tensorrt.dynamo.compile import compile_module
1616
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
1717
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
1818
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
1919

20+
from packaging import version
21+
22+
# Modify import location of utilities based on Torch version
23+
if version.parse(torch_tensorrt.sanitized_torch_version()) <= version.parse("2.1.0"):
24+
from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
25+
else:
26+
from torch._inductor.constant_folding import (
27+
ConstantFolder,
28+
replace_node_with_constant,
29+
)
30+
2031
logger = logging.getLogger(__name__)
2132

2233

0 commit comments

Comments
 (0)