Skip to content

Commit dc8b6d7

Browse files
oscarandersson8218freddan80
authored andcommitted
Fix bug in ScalarsToAttributePass
The pass should not modify the scalar argument if output is non-float. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I36f6975e8d6f33e5834e44959f6e426808452de1
1 parent 2967302 commit dc8b6d7

File tree

3 files changed

+46
-11
lines changed

3 files changed

+46
-11
lines changed

backends/arm/_passes/cast_int64_pass.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,15 @@
55

66
# pyre-unsafe
77

8+
import logging
9+
810
import torch
11+
from executorch.backends.arm._passes.arm_pass_utils import is_param_node
912
from executorch.exir.pass_base import ExportPass, PassResult
13+
from torch._export.utils import is_buffer
14+
15+
logger = logging.getLogger(__name__)
16+
logger.setLevel(logging.WARNING)
1017

1118

1219
class CastInt64ToInt32Pass(ExportPass):
@@ -18,17 +25,31 @@ def _to_int32(self, graph_module: torch.fx.GraphModule):
1825
for node in graph_module.graph.nodes:
1926
fake_tensor = node.meta["val"]
2027
if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
21-
if node.meta["val"].dtype == torch.int64:
22-
node.meta["val"] = node.meta["val"].to(torch.int32)
23-
buffer_name = (
24-
self.exported_program.graph_signature.inputs_to_buffers[
25-
node.name
26-
]
27-
)
28-
new_tensor = self.exported_program.state_dict[buffer_name].to(
29-
torch.int32
30-
)
31-
self.exported_program.state_dict[buffer_name] = new_tensor
28+
if node.meta["val"].dtype == torch.int64 and is_param_node(
29+
self.exported_program, node
30+
):
31+
if is_buffer(self.exported_program, node):
32+
node.meta["val"] = node.meta["val"].to(torch.int32)
33+
buffer_name = (
34+
self.exported_program.graph_signature.inputs_to_buffers[
35+
node.name
36+
]
37+
)
38+
buffer = self.exported_program.state_dict[node.name]
39+
logger.warning(
40+
f"Casting buffer {node.name} from torch.int64 to torch.int32"
41+
f" defined in {node.meta['stack_trace']}"
42+
)
43+
if torch.min(buffer) < torch.iinfo(torch.int32).min:
44+
raise RuntimeError(
45+
f"Buffer {node.name} has value < {torch.iinfo(torch.int32).min}"
46+
)
47+
if torch.max(buffer) > torch.iinfo(torch.int32).max:
48+
raise RuntimeError(
49+
f"Buffer {node.name} has value > {torch.iinfo(torch.int32).max}"
50+
)
51+
buffer_int32 = buffer.to(torch.int32)
52+
self.exported_program.state_dict[buffer_name] = buffer_int32
3253

3354
def call(self, graph_module: torch.fx.GraphModule):
3455
self._to_int32(graph_module)

backends/arm/_passes/scalars_to_attribute_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
5151
if isinstance(arg, Node):
5252
new_args.append(arg)
5353
continue
54+
if isinstance(arg, int) and not torch.is_floating_point(
55+
get_first_fake_tensor(n)
56+
):
57+
new_args.append(arg)
58+
continue
5459

5560
prefix = "_tensor_constant_"
5661
get_new_attr_name = get_new_attr_name_with_prefix(prefix)

backends/arm/test/ops/test_scalars.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ def forward(self, x):
7575
x = 1.0 + x
7676
return x
7777

78+
class ShiftInplaceSub(torch.nn.Module):
79+
def forward(self, x):
80+
x = x >> 4
81+
x -= 10
82+
return x
83+
7884
# Inplace ops end with '_' (from aten naming)
7985
ops = [
8086
("Add", Add()),
@@ -160,3 +166,6 @@ def test_MI_const(self, test_name: str, op: torch.nn.Module, x):
160166
@parameterized.expand(tensor_scalar_tests)
161167
def test_BI(self, test_name: str, op: torch.nn.Module, x, y):
162168
self._test_add_tosa_BI_pipeline(op, (x, y))
169+
170+
def test_shift_sub_inplace_tosa_MI(self):
171+
self._test_add_tosa_MI_pipeline(self.ShiftInplaceSub(), (torch.IntTensor(5),))

0 commit comments

Comments
 (0)