Skip to content

Commit c82c5c2

Browse files
authored
Merge branch 'pytorch:main' into toupstream/select_op
2 parents 1323c7c + e2526cc commit c82c5c2

File tree

87 files changed

+3706
-819
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+3706
-819
lines changed

.github/scripts/propose_ghstack_orig_pr.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def parse_args():
2626
required=True,
2727
)
2828
parser.add_argument(
29-
"--pr",
30-
type=int,
31-
help="Number of the PR in the stack to check and create corresponding PR",
29+
"--ref",
30+
type=str,
31+
help="Ref of the PR in the stack to check and create corresponding PR",
3232
required=True,
3333
)
3434
return parser.parse_args()
@@ -68,12 +68,18 @@ def extract_stack_from_body(pr_body: str) -> List[int]:
6868
return list(reversed(prs))
6969

7070

71-
def get_pr_stack_from_number(pr_number: int, repo: Repository) -> List[int]:
71+
def get_pr_stack_from_number(ref: str, repo: Repository) -> List[int]:
72+
if ref.isnumeric():
73+
pr_number = int(ref)
74+
else:
75+
branch_name = ref.replace("refs/heads/", "")
76+
pr_number = repo.get_branch(branch_name).commit.get_pulls()[0].number
77+
7278
pr_stack = extract_stack_from_body(repo.get_pull(pr_number).body)
7379

7480
if not pr_stack:
7581
raise Exception(
76-
f"Could not find PR stack in body of #{pr_number}. "
82+
f"Could not find PR stack in body of ref. "
7783
+ "Please make sure that the PR was created with ghstack."
7884
)
7985

@@ -100,14 +106,15 @@ def create_prs_for_orig_branch(pr_stack: List[int], repo: Repository):
100106
ghstack PR base: https://github.com/pytorch/executorch/tree/{pr.base.ref}
101107
ghstack PR head: https://github.com/pytorch/executorch/tree/{pr.head.ref}
102108
Merge bot PR base: https://github.com/pytorch/executorch/tree/{orig_branch_merge_base}
103-
Merge bot PR head: https://github.com/pytorch/executorch/tree/{orig_branch_merge_head}"""
109+
Merge bot PR head: https://github.com/pytorch/executorch/tree/{orig_branch_merge_head}
110+
@diff-train-skip-merge"""
104111

105112
existing_orig_pr = repo.get_pulls(
106113
head="pytorch:" + orig_branch_merge_head,
107114
base=orig_branch_merge_base,
108-
state="open",
115+
state="all",
109116
)
110-
if existing_orig_pr.totalCount > 0:
117+
if existing_orig_pr.totalCount > 0 and existing_orig_pr[0].title == pr.title:
111118
print(
112119
f"PR for {orig_branch_merge_head} already exists {existing_orig_pr[0]}"
113120
)
@@ -128,7 +135,7 @@ def main():
128135

129136
with Github(auth=Auth.Token(os.environ["GITHUB_TOKEN"])) as gh:
130137
repo = gh.get_repo(args.repo)
131-
create_prs_for_orig_branch(get_pr_stack_from_number(args.pr, repo), repo)
138+
create_prs_for_orig_branch(get_pr_stack_from_number(args.ref, repo), repo)
132139

133140

134141
if __name__ == "__main__":

.github/workflows/android.yml renamed to .github/workflows/_android.yml

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,9 @@
11
name: Android
22

33
on:
4-
push:
5-
branches:
6-
- main
7-
- release/*
8-
tags:
9-
- ciflow/android/*
10-
pull_request:
11-
paths:
12-
- .ci/docker/**
13-
- .github/workflows/android.yml
14-
- build/*android*.sh
15-
- install_requirements.sh
16-
- examples/demo-apps/android/**
17-
- extension/android/**
18-
- extension/benchmark/android/**
19-
- extension/module/**
4+
workflow_call:
205
workflow_dispatch:
216

22-
concurrency:
23-
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
24-
cancel-in-progress: true
25-
267
jobs:
278
build-llm-demo:
289
name: build-llm-demo

.github/workflows/ghstack_land.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ on:
1515
- 'gh/mcr229/[0-9]+/base'
1616
- 'gh/swolchok/[0-9]+/base'
1717
- 'gh/SS-JIA/[0-9]+/base'
18+
- 'gh/trivedivivek/[0-9]+/base'
1819

1920
jobs:
2021
ghstack_merge_to_main:
@@ -32,9 +33,7 @@ jobs:
3233
run: |
3334
pip install pygithub
3435
35-
PR_NUMBER=$(echo "$GITHUB_REF" | grep -oE '[0-9]+')
36-
37-
python .github/scripts/propose_ghstack_orig_pr.py --pr $PR_NUMBER --repo pytorch/executorch
36+
python .github/scripts/propose_ghstack_orig_pr.py --ref $GITHUB_REF --repo pytorch/executorch
3837
env:
3938
GITHUB_TOKEN: ${{ secrets.GH_PYTORCHBOT_CHERRY_PICK_TOKEN }}
4039
GITHUB_REF: ${{ github.ref }}

.github/workflows/pull.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ jobs:
347347
exit 1
348348
fi
349349
350+
android:
351+
uses: ./.github/workflows/_android.yml
352+
350353
unittest:
351354
uses: ./.github/workflows/_unittest.yml
352355
with:

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,50 @@
99
from typing import cast
1010

1111
import torch
12-
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
12+
from executorch.backends.arm._passes.arm_pass_utils import (
13+
create_node,
14+
get_first_fake_tensor,
15+
)
1316
from executorch.backends.arm.tosa_quant_utils import dq_op
1417
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
18+
from executorch.exir.dialects._ops import ops as exir_ops
1519
from executorch.exir.pass_base import ExportPass, PassResult
20+
from torch.library import impl, Library
21+
22+
# Define lib with passthrough operators. The operators have no real meaning in edge IR
23+
# except for argument validation and a passthrough output. The operators will be used
24+
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
25+
# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
26+
lib = Library("passthrough_to_tosa", "DEF")
27+
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
28+
# to switch dim_order before the operation. Changing tosa_dim_order is not sufficient
29+
# as we also need to transpose the data into the correct data format.
30+
# By utilizing an edge IR passthrough operator we can keep the edge program in
31+
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
32+
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")
33+
34+
35+
@impl(lib, "_transpose")
36+
def _transpose_impl(*args, **kwargs):
37+
# Validate length of dim_order array
38+
dim = args[1]
39+
assert len(dim) <= 4
40+
# Pass-through in edge-IR
41+
return args[0]
1642

1743

1844
class AnnotateChannelsLastDimOrder(ExportPass):
1945
"""
2046
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
21-
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes.
22-
The annotated tosa_dim_order is used to permute the node's shape such that it
23-
gives a TOSA-compliant shape.
47+
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
48+
when a transition between 3D and 4D tensors happens.
49+
The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
2450
"""
2551

52+
NHWC_order = (0, 2, 3, 1)
53+
NHWC_inverse_order = (0, 3, 1, 2)
54+
HWCM_order = (2, 3, 0, 1)
55+
2656
def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
2757
"""
2858
returns True for dq and w in the following sequences;
@@ -49,20 +79,56 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
4979

5080
return False
5181

82+
def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
83+
for node in graph_module.graph.nodes:
84+
if node.op != "call_function":
85+
continue
86+
if node.target == exir_ops.edge.aten.squeeze_copy.dims:
87+
input_node = node.args[0]
88+
if input_node.meta["val"].dim() == 4:
89+
with graph_module.graph.inserting_before(node):
90+
permute_node = create_node(
91+
graph_module.graph,
92+
torch.ops.passthrough_to_tosa._transpose,
93+
args=(input_node, list(self.NHWC_inverse_order)),
94+
)
95+
permute_node.meta["tosa_dim_order"] = tuple(
96+
range(len(input_node.meta["val"].size()))
97+
)
98+
node.replace_input_with(input_node, permute_node)
99+
100+
if node.target == exir_ops.edge.aten.unsqueeze_copy.default:
101+
if node.meta["val"].dim() == 4:
102+
with graph_module.graph.inserting_after(node):
103+
permute_node = create_node(
104+
graph_module.graph,
105+
torch.ops.passthrough_to_tosa._transpose,
106+
args=(node, list(self.NHWC_order)),
107+
)
108+
permute_node.meta["tosa_dim_order"] = self.NHWC_order
109+
node.meta["tosa_dim_order"] = (0, 1, 2, 3)
110+
users = [user for user in node.users if user != permute_node]
111+
for user in users:
112+
user.replace_input_with(node, permute_node)
113+
52114
def call(self, graph_module: torch.fx.GraphModule):
53-
NHWC_Order = (0, 2, 3, 1)
54-
HWCM_Order = (2, 3, 0, 1)
55115
for node in graph_module.graph.nodes:
56116
node_data = get_first_fake_tensor(node).data
57117

58-
if len(node_data.shape) == 4:
59-
dim_order = NHWC_Order
118+
if node_data.dim() == 4:
119+
dim_order = self.NHWC_order
60120
if self.is_weight_node_for_depthwise_conv2d(node):
61121
# The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
62122
# dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
63-
dim_order = HWCM_Order
123+
dim_order = self.HWCM_order
64124
else:
65125
dim_order = tuple(range(node_data.dim()))
66126
node.meta["tosa_dim_order"] = dim_order
127+
# Take care of cases when:
128+
# 4D (NHWC) -> 3D (NCH)
129+
# 3D (NCH) -> 4D (NHWC)
130+
self.insert_tosa_transposes(graph_module)
67131
graph_module.recompile()
132+
graph_module = super().call(graph_module).graph_module
133+
68134
return PassResult(graph_module, True)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
ConvertSplitToSlicePass,
2020
)
2121
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
22+
from executorch.backends.arm._passes.decompose_layernorm_pass import (
23+
DecomposeLayerNormPass,
24+
)
25+
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
26+
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
2227
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
2328
InsertSqueezeAfterSumPass,
2429
)
@@ -53,7 +58,10 @@ def transform_to_backend_pipeline(
5358
self.add_pass(SizeAdjustConv2DPass())
5459
self.add_pass(RemoveClonePass())
5560
self.add_pass(ConvertExpandCopyToRepeatPass())
61+
self.add_pass(DecomposeLayerNormPass())
62+
self.add_pass(DecomposeVarPass())
5663
self.add_pass(ConvertMeanDimToAveragePool())
64+
self.add_pass(DecomposeMeanDimPass())
5765
self.add_pass(MatchArgRanksPass(exported_program))
5866
self.add_pass(DecomposeDivPass())
5967
self.add_pass(InsertSqueezeAfterSumPass())
@@ -67,6 +75,9 @@ def transform_to_backend_pipeline(
6775
return self._transform(exported_program.graph_module)
6876

6977
def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
78+
self.add_pass(DecomposeLayerNormPass())
79+
self.add_pass(DecomposeVarPass())
80+
self.add_pass(DecomposeMeanDimPass())
7081
self.add_pass(ScalarsToAttributePass())
7182
self.add_pass(DecomposeDivPass())
7283
return self._transform(graph_module)

0 commit comments

Comments
 (0)