Skip to content

Commit 28c9548

Browse files
Erik-Lundellfacebook-github-bot
authored andcommitted
Add CastInt64ToInt32Pass (#5842)
Summary: Integer placeholders, for example in this examples x = torch.Tensor([1.]) + 1 Are by default lowered as int64 tensors. As int64 is not a valid TOSA dtype, we cast these tensors to int32. Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_supported_number_formats Change-Id: Ice59625b7fd68ff3a544ee4648fe05f69b9333e6 Pull Request resolved: #5842 Reviewed By: mergennachin Differential Revision: D64047286 Pulled By: digantdesai fbshipit-source-id: 66cce1c5e360082cbe168d45ef2ba9ee492e8787
1 parent e540bcb commit 28c9548

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

backends/arm/passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from executorch.backends.arm.passes.annotate_channels_last_dim_order_pass import (
1212
AnnotateChannelsLastDimOrder,
1313
)
14+
from executorch.backends.arm.passes.cast_int64_pass import CastInt64ToInt32Pass
1415
from executorch.backends.arm.passes.convert_expand_copy_to_repeat import (
1516
ConvertExpandCopyToRepeatPass,
1617
)
@@ -40,6 +41,7 @@ def transform_to_backend_pipeline(
4041
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
4142
):
4243
"""Apply passes before transforming program to backend"""
44+
self.add_pass(CastInt64ToInt32Pass(exported_program))
4345
self.add_pass(SizeAdjustConv2DPass())
4446
self.add_pass(RemoveClonePass())
4547
self.add_pass(ConvertExpandCopyToRepeatPass())
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass, PassResult
8+
9+
10+
class CastInt64ToInt32Pass(ExportPass):
11+
def __init__(self, exported_program: torch.export.ExportedProgram):
12+
super(CastInt64ToInt32Pass, self).__init__()
13+
self.exported_program = exported_program
14+
15+
def _to_int32(self, graph_module: torch.fx.GraphModule):
16+
for node in graph_module.graph.nodes:
17+
fake_tensor = node.meta["val"]
18+
if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
19+
if node.meta["val"].dtype == torch.int64:
20+
node.meta["val"] = node.meta["val"].to(torch.int32)
21+
buffer_name = (
22+
self.exported_program.graph_signature.inputs_to_buffers[
23+
node.name
24+
]
25+
)
26+
new_tensor = self.exported_program.state_dict[buffer_name].to(
27+
torch.int32
28+
)
29+
self.exported_program.state_dict[buffer_name] = new_tensor
30+
31+
def call(self, graph_module: torch.fx.GraphModule):
32+
self._to_int32(graph_module)
33+
graph_module.recompile()
34+
graph_module = super().call(graph_module).graph_module
35+
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)