Skip to content

Commit 29797d4

Browse files
authored
Add support for lifted tensors in ArmPartitioner
Differential Revision: D61542388 Pull Request resolved: #4788
1 parent ef9c07f commit 29797d4

File tree

2 files changed

+65
-2
lines changed

2 files changed

+65
-2
lines changed

backends/arm/operators/op_placeholder.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
import serializer.tosa_serializer as ts
8-
import torch
8+
import torch.fx
99
from executorch.backends.arm.tosa_mapping import TosaArg
1010
from executorch.backends.arm.tosa_quant_utils import (
1111
get_quant_arg_dtype,
@@ -130,6 +130,21 @@ def process_inputs_to_buffers(
130130
)
131131

132132

133+
def process_inputs_to_lifted_tensor_constants(
134+
node: torch.fx.Node,
135+
tosa_graph: ts.TosaSerializer,
136+
edge_program: ExportedProgram,
137+
):
138+
arg = TosaArg(node)
139+
tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
140+
arg.name
141+
]
142+
tensor = edge_program.tensor_constants[tensor_name]
143+
tensor_data = tensor.detach().numpy()
144+
145+
tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name)
146+
147+
133148
def process_placeholder(
134149
node: torch.fx.Node,
135150
tosa_graph: ts.TosaSerializer,
@@ -145,5 +160,11 @@ def process_placeholder(
145160
process_inputs_to_parameters(node, tosa_graph, edge_program)
146161
elif node.name in edge_program.graph_signature.inputs_to_buffers:
147162
process_inputs_to_buffers(node, tosa_graph, edge_program)
163+
elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:
164+
process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
165+
elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs:
166+
raise NotImplementedError(
167+
"Placeholder is of type 'lifted custom object' which is not supported."
168+
)
148169
else:
149-
raise RuntimeError(f"Unknown placeholder {node.name}")
170+
raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import unittest
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
11+
12+
13+
class LiftedTensor(torch.nn.Module):
14+
15+
def __init__(self):
16+
super().__init__()
17+
self.lifted_tensor = torch.Tensor([[1, 2], [3, 4]])
18+
19+
def forward(self, x: torch.Tensor, length) -> torch.Tensor:
20+
sliced = self.lifted_tensor[:, :length]
21+
return sliced + x
22+
23+
24+
class TestLiftedTensor(unittest.TestCase):
25+
"""Tests the ArmPartitioner with a placeholder of type lifted tensor."""
26+
27+
def test_partition_lifted_tensor(self):
28+
tester = (
29+
ArmTester(
30+
LiftedTensor(),
31+
example_inputs=(torch.ones(2, 2), 2),
32+
compile_spec=common.get_tosa_compile_spec(),
33+
)
34+
.export()
35+
.to_edge()
36+
.dump_artifact()
37+
)
38+
signature = tester.get_artifact().exported_program().graph_signature
39+
assert len(signature.lifted_tensor_constants) > 0
40+
tester.partition()
41+
tester.to_executorch()
42+
tester.run_method_and_compare_outputs((torch.ones(2, 2), 2))

0 commit comments

Comments
 (0)