File tree Expand file tree Collapse file tree 1 file changed +13
-2
lines changed
py/torch_tensorrt/dynamo/lowering/passes Expand file tree Collapse file tree 1 file changed +13
-2
lines changed Original file line number Diff line number Diff line change 1
1
import logging
2
- from typing import Sequence
2
+ from typing import Any , Sequence
3
3
4
4
import torch
5
5
from torch_tensorrt ._utils import sanitized_torch_version
@@ -32,7 +32,7 @@ def constant_fold(
32
32
33
33
Modifies the graph in-place and replaces node with constants
34
34
"""
35
- cf = ConstantFolder (gm , skip_constructors = False )
35
+ cf = _TorchTensorRTConstantFolder (gm , skip_constructors = False )
36
36
cf .run ()
37
37
38
38
for node , constant in cf .node_replacements .items ():
@@ -58,3 +58,14 @@ def constant_fold(
58
58
logger .debug (f"Graph after constant folding:\n { gm .graph } " )
59
59
60
60
return gm
61
+
62
+
63
+ # TODO: Delete this class when the following code is fixed in nightly:
64
+ # https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
65
+ class _TorchTensorRTConstantFolder (ConstantFolder ): # type: ignore[misc]
66
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
67
+ super ().__init__ (* args , ** kwargs )
68
+
69
+ # TODO: Update this function when quantization is added
70
+ def is_impure (self , node : torch .fx .node .Node ) -> bool :
71
+ return False
You can’t perform that action at this time.
0 commit comments