11
11
12
12
# Modify import location of utilities based on Torch version
13
13
if version .parse (sanitized_torch_version ()) < version .parse ("2.1.1" ):
14
- from torch ._inductor .freezing import ConstantFolder , replace_node_with_constant
14
+ from torch ._inductor .freezing import ConstantFolder
15
15
else :
16
- from torch ._inductor .constant_folding import (
17
- ConstantFolder ,
18
- replace_node_with_constant ,
19
- )
16
+ from torch ._inductor .constant_folding import ConstantFolder
20
17
21
18
logger = logging .getLogger (__name__ )
22
19
@@ -36,7 +33,9 @@ def constant_fold(
36
33
cf .run ()
37
34
38
35
for node , constant in cf .node_replacements .items ():
39
- replace_node_with_constant (gm , node , constant )
36
+ replace_node_with_constant (
37
+ gm , node , torch .nn .Parameter (constant .cuda (), requires_grad = False )
38
+ )
40
39
41
40
erased_params = []
42
41
for node in gm .graph .nodes :
@@ -55,6 +54,40 @@ def constant_fold(
55
54
return gm
56
55
57
56
57
def replace_node_with_constant(
    gm: torch.fx.GraphModule, node: torch.fx.Node, constant: torch.Tensor
) -> None:
    """Swap a folded graph node for a ``get_attr`` referencing a frozen constant.

    Adapted from:
    https://github.com/pytorch/pytorch/blob/bcf35c6ae62bb6560befa3550e37a8283944e5f4/torch/_inductor/constant_folding.py#L17-L43

    Modified to register parameters, instead of buffers for frozen constants
    """
    graph = gm.graph

    # Lazily initialize the running counter used to name frozen constants.
    if not hasattr(gm, "_frozen_param_count"):
        gm._frozen_param_count = 0

    # Scan forward from the counter until a free attribute name is found,
    # so repeated calls never clobber an existing frozen parameter.
    idx = gm._frozen_param_count
    qualname = f"_frozen_param{idx}"
    while hasattr(gm, qualname):
        idx += 1
        qualname = f"_frozen_param{idx}"

    gm._frozen_param_count = idx + 1

    # Rewire the graph: insert a get_attr node just before the folded node,
    # redirect every use to it, carry the metadata over, then drop the original.
    with graph.inserting_before(node):
        getattr_node = graph.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(getattr_node)
        getattr_node.meta.update(node.meta)
        graph.erase_node(node)

    # Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_parameter(qualname, constant)
    setattr(gm, qualname, constant)
89
+
90
+
58
91
# TODO: Delete this class when the following code is fixed in nightly:
59
92
# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
60
93
class _TorchTensorRTConstantFolder (ConstantFolder ): # type: ignore[misc]
0 commit comments