5
5
6
6
# pyre-unsafe
7
7
8
+ import logging
9
+
8
10
import torch
11
+ from executorch .backends .arm ._passes .arm_pass_utils import is_param_node
9
12
from executorch .exir .pass_base import ExportPass , PassResult
13
+ from torch ._export .utils import is_buffer
14
+
15
+ logger = logging .getLogger (__name__ )
16
+ logger .setLevel (logging .WARNING )
10
17
11
18
12
19
class CastInt64ToInt32Pass (ExportPass ):
@@ -18,17 +25,31 @@ def _to_int32(self, graph_module: torch.fx.GraphModule):
18
25
for node in graph_module .graph .nodes :
19
26
fake_tensor = node .meta ["val" ]
20
27
if isinstance (fake_tensor , torch ._subclasses .fake_tensor .FakeTensor ):
21
- if node .meta ["val" ].dtype == torch .int64 :
22
- node .meta ["val" ] = node .meta ["val" ].to (torch .int32 )
23
- buffer_name = (
24
- self .exported_program .graph_signature .inputs_to_buffers [
25
- node .name
26
- ]
27
- )
28
- new_tensor = self .exported_program .state_dict [buffer_name ].to (
29
- torch .int32
30
- )
31
- self .exported_program .state_dict [buffer_name ] = new_tensor
28
+ if node .meta ["val" ].dtype == torch .int64 and is_param_node (
29
+ self .exported_program , node
30
+ ):
31
+ if is_buffer (self .exported_program , node ):
32
+ node .meta ["val" ] = node .meta ["val" ].to (torch .int32 )
33
+ buffer_name = (
34
+ self .exported_program .graph_signature .inputs_to_buffers [
35
+ node .name
36
+ ]
37
+ )
38
+ buffer = self .exported_program .state_dict [node .name ]
39
+ logger .warning (
40
+ f"Casting buffer { node .name } from torch.int64 to torch.int32"
41
+ f" defined in { node .meta ['stack_trace' ]} "
42
+ )
43
+ if torch .min (buffer ) < torch .iinfo (torch .int32 ).min :
44
+ raise RuntimeError (
45
+ f"Buffer { node .name } has value < { torch .iinfo (torch .int32 ).min } "
46
+ )
47
+ if torch .max (buffer ) > torch .iinfo (torch .int32 ).max :
48
+ raise RuntimeError (
49
+ f"Buffer { node .name } has value > { torch .iinfo (torch .int32 ).max } "
50
+ )
51
+ buffer_int32 = buffer .to (torch .int32 )
52
+ self .exported_program .state_dict [buffer_name ] = buffer_int32
32
53
33
54
def call (self , graph_module : torch .fx .GraphModule ):
34
55
self ._to_int32 (graph_module )
0 commit comments