Commit 1bcb626

Merge remote-tracking branch 'origin/main' into toupstream/select_op

2 parents c82c5c2 + cbfdf78

File tree

27 files changed: +538, -276 lines

.ci/docker/ci_commit_pins/pytorch.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-d1b87e26e5c4343f5b56bb1e6f89b479b389bfac
+export-D64151426

.github/workflows/apple-perf.yml

Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ jobs:
           # on-demand and periodic benchmarking.
           CRON_DEFAULT_MODELS: "stories110M,mv3,mv2,ic4,ic3,resnet50,edsr,mobilebert,w2l"
           CRON_DEFAULT_DEVICES: "apple_iphone_15"
-          CRON_DEFAULT_DELEGATES: "nnpack,coreml,mps"
+          CRON_DEFAULT_DELEGATES: "xnnpack,coreml,mps"
         run: |
           set -ex
           MODELS="${{ inputs.models }}"

README.md

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ Check out the [Getting Started](https://pytorch.org/executorch/stable/getting-st
 Check out the examples of [Llama](./examples/models/llama/README.md), [Llava](./examples/models/llava/README.md) and [other models](./examples/README.md) running on edge devices using ExecuTorch.
 
 
-**[UPDATE - 09/25]** We have added support for running [Llama 3.2 1B/3B](./examples/models/llama/README.md) models via ExecuTorch.
+**[UPDATE - 10/24]** We have added support for running [Llama 3.2 Quantized 1B/3B](./examples/models/llama/README.md) models via ExecuTorch.
 
 ## Feedback
 

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 79 additions & 26 deletions

@@ -12,8 +12,9 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    insert_q_dq_pair,
 )
-from executorch.backends.arm.tosa_quant_utils import dq_op
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

@@ -79,37 +80,89 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
 
         return False
 
+    def insert_input_transpose(self, node, input_node, graph_module):
+        quantize = input_node.target == dq_op
+        q_params = input_node.args[1:] if quantize else None
+        with graph_module.graph.inserting_before(node):
+            permute_node = create_node(
+                graph_module.graph,
+                torch.ops.passthrough_to_tosa._transpose,
+                args=(input_node, list(self.NHWC_inverse_order)),
+                quantize=quantize,
+                q_params=q_params,
+            )
+            node.replace_input_with(input_node, permute_node)
+
+            permute_node.meta["tosa_dim_order"] = tuple(
+                range(len(input_node.meta["val"].size()))
+            )
+
+    def insert_output_transpose(self, node, graph_module):
+        with graph_module.graph.inserting_after(node):
+            permute_node = create_node(
+                graph_module.graph,
+                torch.ops.passthrough_to_tosa._transpose,
+                args=(node, list(self.NHWC_order)),
+            )
+            permute_node.meta["tosa_dim_order"] = self.NHWC_order
+            node.meta["tosa_dim_order"] = (0, 1, 2, 3)
+            users = [user for user in node.users if user != permute_node]
+            for user in users:
+                user.replace_input_with(node, permute_node)
+
+            quantize = node.args[0] == q_op
+            if quantize:
+                q_params = node.args[0].args[1:]
+                insert_q_dq_pair(graph_module.graph, node, q_params)
+
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
+        """
+        Reshape operations are not equivalent in NCHW and NHWC.
+        To get around this, transposes need to be added if the previous or new shape
+        fulfil the following condition:
+            C > 1 and (H or W > 1)
+
+        This is relevant for the following operations:
+            squeeze:    4D ->  3D
+            unsqueeze: <4D ->  4D
+            view:      <4D ->  4D
+            view:       4D -> <4D
+            view:       4D ->  4D
+        """
+
+        def transpose_condition(shape):
+            if len(shape) != 4:
+                return False
+            C = shape[1]
+            H = shape[2]
+            W = shape[3]
+            return C > 1 and (H > 1 or W > 1)
+
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
             if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                 input_node = node.args[0]
-                if input_node.meta["val"].dim() == 4:
-                    with graph_module.graph.inserting_before(node):
-                        permute_node = create_node(
-                            graph_module.graph,
-                            torch.ops.passthrough_to_tosa._transpose,
-                            args=(input_node, list(self.NHWC_inverse_order)),
-                        )
-                        permute_node.meta["tosa_dim_order"] = tuple(
-                            range(len(input_node.meta["val"].size()))
-                        )
-                        node.replace_input_with(input_node, permute_node)
-
-            if node.target == exir_ops.edge.aten.unsqueeze_copy.default:
-                if node.meta["val"].dim() == 4:
-                    with graph_module.graph.inserting_after(node):
-                        permute_node = create_node(
-                            graph_module.graph,
-                            torch.ops.passthrough_to_tosa._transpose,
-                            args=(node, list(self.NHWC_order)),
-                        )
-                        permute_node.meta["tosa_dim_order"] = self.NHWC_order
-                        node.meta["tosa_dim_order"] = (0, 1, 2, 3)
-                        users = [user for user in node.users if user != permute_node]
-                        for user in users:
-                            user.replace_input_with(node, permute_node)
+                input_shape = input_node.meta["val"].shape
+                if transpose_condition(input_shape):
+                    self.insert_input_transpose(node, input_node, graph_module)
+
+            elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
+                output_shape = node.meta["val"].shape
+                if transpose_condition(output_shape):
+                    self.insert_output_transpose(node, graph_module)
+
+            elif node.target == exir_ops.edge.aten.view_copy.default:
+                input_node = node.args[0]
+
+                old_shape = input_node.meta["val"].shape
+                new_shape = node.meta["val"].shape
+
+                if transpose_condition(old_shape):
+                    self.insert_input_transpose(node, input_node, graph_module)
+
+                if transpose_condition(new_shape):
+                    self.insert_output_transpose(node, graph_module)
 
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
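
To make the guard concrete, here is a minimal sketch (not part of the commit, plain PyTorch only) of why a reshape is layout-sensitive exactly when C > 1 and (H > 1 or W > 1): flattening the same logical tensor yields different element orders under NCHW and NHWC storage, unless the extra dimensions are singletons.

import torch

# N=1, C=2, H=3, W=4: the transpose_condition above holds.
x_nchw = torch.arange(2 * 3 * 4).reshape(1, 2, 3, 4)
x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()  # same values, NHWC storage

# The flattened element orders differ, so a view/squeeze on NHWC-stored data
# would produce wrong results without first transposing back to NCHW.
print(torch.equal(x_nchw.reshape(-1), x_nhwc.reshape(-1)))  # False

# With C == 1 the two storage orders coincide and no transpose is needed.
y_nchw = torch.arange(3 * 4).reshape(1, 1, 3, 4)
y_nhwc = y_nchw.permute(0, 2, 3, 1).contiguous()
print(torch.equal(y_nchw.reshape(-1), y_nhwc.reshape(-1)))  # True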

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 1 deletion

@@ -23,6 +23,9 @@
     DecomposeLayerNormPass,
 )
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
+from executorch.backends.arm._passes.decompose_softmaxes_pass import (
+    DecomposeSoftmaxesPass,
+)
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
 from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
     InsertSqueezeAfterSumPass,

@@ -66,6 +69,7 @@ def transform_to_backend_pipeline(
         self.add_pass(DecomposeDivPass())
         self.add_pass(InsertSqueezeAfterSumPass())
         self.add_pass(ConvertSplitToSlicePass())
+        self.add_pass(DecomposeSoftmaxesPass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()

@@ -75,9 +79,10 @@ def transform_to_backend_pipeline(
         return self._transform(exported_program.graph_module)
 
     def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
+        self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
-        self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeDivPass())
+        self.add_pass(DecomposeSoftmaxesPass())
         return self._transform(graph_module)
backends/arm/_passes/decompose_softmaxes_pass.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+# For BI case
+torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
+
+# For MI case
+edge_softmax = (
+    exir_ops.edge.aten._softmax.default,
+    exir_ops.edge.aten._log_softmax.default,
+)
+
+log_softmax = (torch.ops.aten.log_softmax.int, exir_ops.edge.aten._log_softmax.default)
+
+
+def get_logsoftmax_ops(op) -> tuple:
+    """
+    Returns the (log_op, exp_op, sum_op, reciprocal_op, mul_op) tuple, where the
+    ops depend on whether the (log)softmax op is from exir_ops or torch.ops.aten.
+    """
+    if op in edge_softmax:
+        return (
+            exir_ops.edge.aten.log.default,
+            exir_ops.edge.aten.exp.default,
+            exir_ops.edge.aten.sum.dim_IntList,
+            exir_ops.edge.aten.reciprocal.default,
+            exir_ops.edge.aten.mul.Tensor,
+        )
+    if op in torch_softmax:
+        return (
+            torch.ops.aten.log.default,
+            torch.ops.aten.exp.default,
+            torch.ops.aten.sum.dim_IntList,
+            torch.ops.aten.reciprocal.default,
+            torch.ops.aten.mul.Tensor,
+        )
+    raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
+
+
+class DecomposeSoftmaxesPass(ExportPass):
+    """
+    This pass decomposes log softmax or softmax into more primitive ops.
+
+    Example:
+        %op1 = exp(x)
+        %op2 = sum(%op1, dim)
+        %op3 = reciprocal(%op2)
+        %op4 = mul(%op1, %op3)
+        (in logsoftmax case: %op5 = log(%op4))
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in torch_softmax + edge_softmax:
+            return super().call_operator(op, args, kwargs, meta)
+
+        log_op, exp_op, sum_op, reciprocal_op, mul_op = get_logsoftmax_ops(op)
+
+        _input = args[0]
+        dim = [args[1]]
+
+        op1 = super().call_operator(exp_op, (_input,), {}, meta)
+        op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
+        op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
+        op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
+        if op in log_softmax:
+            op4 = super().call_operator(log_op, (op4,), {}, meta)
+        return op4
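
A quick eager-mode sanity check (not part of the commit; plain torch ops standing in for the aten/exir ops the pass emits) confirms that the exp/sum/reciprocal/mul chain reproduces softmax and, with a final log, log_softmax:

import torch

x = torch.randn(2, 5)
dim = -1

op1 = torch.exp(x)                             # exp_op
op2 = torch.sum(op1, dim=[dim], keepdim=True)  # sum_op with keepdim=True
op3 = torch.reciprocal(op2)                    # reciprocal_op
op4 = op1 * op3                                # mul_op -> softmax(x)

print(torch.allclose(op4, torch.softmax(x, dim=dim)))                 # True
print(torch.allclose(torch.log(op4), torch.log_softmax(x, dim=dim)))  # True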

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.rsqrt.default,
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.select_copy.int,
+            exir_ops.edge.aten._log_softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.sum.dim_IntList,

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -31,7 +31,6 @@
     op_select,
     op_sigmoid,
     op_slice,
-    op_softmax,
     op_squeeze,
     op_sub,
     op_sum,

backends/arm/operators/op_exp.py

Lines changed: 0 additions & 1 deletion

@@ -42,7 +42,6 @@ def define_node(
     ) -> None:
 
         assert len(node.all_input_nodes) == 1
-        assert len(node.users) == 1
 
         if is_quant_node:
            # Assume quantized input is 8 bit.

backends/arm/operators/op_softmax.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

backends/arm/test/misc/test_debug_feats.py

Lines changed: 1 addition & 2 deletions

@@ -107,9 +107,8 @@ def test_numerical_diff_prints(self):
             ArmTester(
                 model,
                 example_inputs=model.get_inputs(),
-                compile_spec=common.get_tosa_compile_spec(),
+                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
             )
-            .quantize()
             .export()
             .to_edge()
             .partition()
