|
| 1 | +# Copyright 2024 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import torch |
| 7 | +from executorch.exir.pass_base import ExportPass, PassResult |
| 8 | + |
| 9 | + |
| 10 | +class CastInt64ToInt32Pass(ExportPass): |
| 11 | + def __init__(self, exported_program: torch.export.ExportedProgram): |
| 12 | + super(CastInt64ToInt32Pass, self).__init__() |
| 13 | + self.exported_program = exported_program |
| 14 | + |
| 15 | + def _to_int32(self, graph_module: torch.fx.GraphModule): |
| 16 | + for node in graph_module.graph.nodes: |
| 17 | + fake_tensor = node.meta["val"] |
| 18 | + if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): |
| 19 | + if node.meta["val"].dtype == torch.int64: |
| 20 | + node.meta["val"] = node.meta["val"].to(torch.int32) |
| 21 | + buffer_name = ( |
| 22 | + self.exported_program.graph_signature.inputs_to_buffers[ |
| 23 | + node.name |
| 24 | + ] |
| 25 | + ) |
| 26 | + new_tensor = self.exported_program.state_dict[buffer_name].to( |
| 27 | + torch.int32 |
| 28 | + ) |
| 29 | + self.exported_program.state_dict[buffer_name] = new_tensor |
| 30 | + |
| 31 | + def call(self, graph_module: torch.fx.GraphModule): |
| 32 | + self._to_int32(graph_module) |
| 33 | + graph_module.recompile() |
| 34 | + graph_module = super().call(graph_module).graph_module |
| 35 | + return PassResult(graph_module, True) |
0 commit comments