
Commit 0c3e306

Update base for Update on "Use std::variant to implement pytree Key"
Key was a struct that should've been a union; std::variant makes using a union much easier.

Differential Revision: [D65575184](https://our.internmc.facebook.com/intern/diff/D65575184/)

[ghstack-poisoned]
2 parents: 108116c + 39e5b91 · commit 0c3e306
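For context on the commit message above: a pytree key is logically one of a small set of alternatives (for example a dictionary key or a sequence index), so modelling it as a struct that carries every field at once forces manual tag bookkeeping. Below is a minimal, self-contained sketch of the std::variant approach; the `Key` alias, the alternative types, and the `key_to_string` helper are illustrative assumptions for this example, not the actual ExecuTorch pytree definitions.

```cpp
// Illustrative sketch only: these names are assumptions for the example,
// not the actual executorch pytree Key implementation.
#include <cstdint>
#include <iostream>
#include <string>
#include <type_traits>
#include <variant>

// A pytree key is either a dict key (string) or a sequence index (integer).
// std::variant gives union-like storage, but it tracks the active
// alternative itself, so no hand-maintained tag field is needed.
using Key = std::variant<std::string, std::int64_t>;

std::string key_to_string(const Key& key) {
  // std::visit dispatches on whichever alternative the variant currently holds.
  return std::visit(
      [](const auto& k) -> std::string {
        if constexpr (std::is_same_v<std::decay_t<decltype(k)>, std::string>) {
          return "'" + k + "'";
        } else {
          return std::to_string(k);
        }
      },
      key);
}

int main() {
  Key dict_key = std::string("weight");
  Key index_key = std::int64_t{3};
  std::cout << key_to_string(dict_key) << " " << key_to_string(index_key) << "\n";
  // Type-safe queries replace manual tag checks on a struct or raw union.
  std::cout << std::boolalpha << std::holds_alternative<std::string>(dict_key) << "\n";
  return 0;
}
```

Compared with a raw union, std::variant remembers which alternative is active, so std::visit, std::holds_alternative, and std::get take the place of hand-written tag checks, which is the convenience the commit message points to.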


44 files changed: +3636, -900 lines

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,7 @@
 from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
     UnsqueezeScalarPlaceholdersPass,
 )
+from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -58,6 +59,7 @@ def transform_to_backend_pipeline(
     ):
         """Apply passes before transforming program to backend"""
         self.add_pass(CastInt64ToInt32Pass(exported_program))
+        self.add_pass(RemoveGetItemPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(RemoveClonePass())

backends/arm/arm_backend.py

Lines changed: 3 additions & 6 deletions
@@ -13,7 +13,7 @@

 import logging
 import os
-from typing import final, List, Optional
+from typing import cast, final, List, Optional

 import serializer.tosa_serializer as ts
 from executorch.backends.arm.arm_vela import vela_compile
@@ -31,6 +31,7 @@
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch.export.exported_program import ExportedProgram
+from torch.fx import Node

 # TOSA backend debug functionality
 logger = logging.getLogger(__name__)
@@ -225,6 +226,7 @@ def preprocess(  # noqa: C901
         node_visitors = get_node_visitors(edge_program)

         for node in graph_module.graph.nodes:
+            node = cast(Node, node)
             if node.op == "call_function":
                 process_call_function(node, tosa_graph, node_visitors)
             elif node.op == "placeholder":
@@ -236,9 +238,6 @@ def preprocess(  # noqa: C901
             # any checking of compatibility.
             dbg_fail(node, tosa_graph, artifact_path)

-        # TODO: It would be awesome if this dump could somehow be done on top level and not here.
-        # Problem is that the desc.json has to be created on the tosa_graph object, which we can't
-        # access from top level.
         if artifact_path:
             tag = _get_first_delegation_tag(graph_module)
             dbg_tosa_dump(
@@ -259,6 +258,4 @@ def preprocess(  # noqa: C901
         else:
             raise RuntimeError(f"Unknown format {output_format}")

-        # Continueing from above. Can I put tosa_graph into this function?
-        # debug_handle_map = ...
         return PreprocessResult(processed_bytes=binary)

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.native_layer_norm.default,
             exir_ops.edge.aten.avg_pool2d.default,
+            exir_ops.edge.aten.max_pool2d_with_indices.default,
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.repeat.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
     op_get_item,
     op_hardtanh,
     op_log,
+    op_max_pool2d,
     op_mm,
     op_mul,
     op_permute,
backends/arm/operators/op_max_pool2d.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import cast, List
+
+import serializer.tosa_serializer as ts
+import torch
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_utils import get_quant_node_args
+
+from serializer.tosa_serializer import TosaOp
+
+
+@register_node_visitor
+class MaxPool2dVisitor(NodeVisitor):
+    target = "aten.max_pool2d.default"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+
+        input_tensor = inputs[0]
+        kernel_size = inputs[1].special
+        stride = inputs[2].special
+
+        try:
+            padding = [*inputs[3].special, *inputs[3].special]
+        except IndexError:
+            padding = [0, 0, 0, 0]
+
+        accumulator_type = input_tensor.dtype
+
+        if is_quant_node:
+            # Accumulator type is always int8 when the input tensor is an integer type.
+            accumulator_type = ts.DType.INT8
+
+        # Initialize zero points to zero.
+        input_zp = 0
+        output_zp = 0
+
+        if is_quant_node:
+            input_zp = get_quant_node_args(
+                cast(torch.fx.Node, node.all_input_nodes[0])
+            ).zp
+            output_zp = get_quant_node_args(list(node.users)[0]).zp
+
+        attr = ts.TosaSerializerAttribute()
+        attr.PoolAttribute(
+            kernel=kernel_size,
+            stride=stride,
+            pad=padding,
+            input_zp=input_zp,
+            output_zp=output_zp,
+            accum_dtype=accumulator_type,
+        )
+
+        tosa_graph.addOperator(
+            TosaOp.Op().MAX_POOL2D,
+            [input_tensor.name],
+            [output.name],
+            attr,
+        )
backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
         # TODO: remove?
         torch.ops.aten.adaptive_avg_pool2d.default,
         torch.ops.aten.avg_pool2d.default,
+        torch.ops.aten.max_pool2d.default,
         torch.ops.aten.full.default,
        torch.ops.aten.flatten.using_ints,
         torch.ops.aten.dropout.default,

backends/arm/test/common.py

Lines changed: 15 additions & 8 deletions
@@ -91,6 +91,17 @@ def pytest_sessionfinish(session, exitstatus):

 # ==== End of Pytest hooks =====

+# ==== Custom Pytest decorators =====
+
+
+def expectedFailureOnFVP(test_item):
+    if is_option_enabled("corstone300"):
+        test_item.__unittest_expecting_failure__ = True
+    return test_item
+
+
+# ==== End of Custom Pytest decorators =====
+

 def load_libquantized_ops_aot_lib():
     so_ext = {
@@ -181,19 +192,15 @@ def get_tosa_compile_spec_unbuilt(
     the compile spec before calling .build() to finalize it.
     """
     if not custom_path:
-        intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
-            prefix="arm_tosa_"
-        )
-    else:
-        intermediate_path = custom_path
+        custom_path = maybe_get_tosa_collate_path()

-    if not os.path.exists(intermediate_path):
-        os.makedirs(intermediate_path, exist_ok=True)
+    if custom_path is not None and not os.path.exists(custom_path):
+        os.makedirs(custom_path, exist_ok=True)
     compile_spec_builder = (
         ArmCompileSpecBuilder()
         .tosa_compile_spec()
         .set_permute_memory_format(permute_memory_to_nhwc)
-        .dump_intermediate_artifacts_to(intermediate_path)
+        .dump_intermediate_artifacts_to(custom_path)
     )

     return compile_spec_builder

backends/arm/test/misc/test_debug_feats.py

Lines changed: 4 additions & 1 deletion
@@ -107,7 +107,10 @@ def test_numerical_diff_prints(self):
             ArmTester(
                 model,
                 example_inputs=model.get_inputs(),
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(
+                    permute_memory_to_nhwc=True,
+                    custom_path=tempfile.mkdtemp("diff_print_test"),
+                ),
             )
             .export()
             .to_edge()

backends/arm/test/ops/test_cat.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
     def test_cat_4d_tosa_MI(self):
         square = torch.ones((2, 2, 2, 2))
         for dim in range(-3, 3):
-            test_data = ((square, square), dim)
+            test_data = ((square, square.clone()), dim)
             self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)

     @parameterized.expand(Cat.test_parameters)
