Skip to content

Commit d10bbd1

Browse files
committed
Update base for Update on "[XNNPACK][Partitioner] SDPA Config"
We add the SDPA Config here for partitioner. Currently there is an issue with SDPA when used from the FairSeq Multihead attention models, so I currently have it disabled for the base partitioner until we resolve that. Otherwise, for our tests, we can use the SDPA correctly from there. We have to track D60553559. Will follow up on this later. Differential Revision: [D60323285](https://our.internmc.facebook.com/intern/diff/D60323285/) [ghstack-poisoned]
2 parents c2caa04 + 9e478e8 commit d10bbd1

35 files changed

+1878
-91
lines changed

.ci/scripts/test_llava.sh

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#!/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
set -exu
9+
# shellcheck source=/dev/null
10+
11+
if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
12+
PYTHON_EXECUTABLE=python3
13+
fi
14+
15+
cmake_install_executorch_libraries() {
16+
cmake \
17+
-DCMAKE_INSTALL_PREFIX=cmake-out \
18+
-DCMAKE_BUILD_TYPE=Debug \
19+
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
20+
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
21+
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
22+
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
23+
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
24+
-DEXECUTORCH_BUILD_XNNPACK=ON \
25+
-DEXECUTORCH_DO_NOT_USE_CXX11_ABI=ON \
26+
-Bcmake-out .
27+
28+
29+
cmake --build cmake-out -j9 --target install --config Debug
30+
}
31+
32+
cmake_build_llava_runner() {
33+
dir=examples/models/llava
34+
python_lib=$($PYTHON_EXECUTABLE -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')
35+
36+
cmake \
37+
-DCMAKE_INSTALL_PREFIX=cmake-out \
38+
-DCMAKE_BUILD_TYPE=Debug \
39+
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
40+
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
41+
-DEXECUTORCH_BUILD_XNNPACK=ON \
42+
-DCMAKE_PREFIX_PATH="$python_lib" \
43+
-Bcmake-out/${dir} \
44+
${dir}
45+
46+
47+
cmake --build cmake-out/${dir} -j9 --config Debug
48+
}
49+
50+
# Only export the variant without the custom op for now.
51+
export_llava() {
52+
echo "Starting to export Llava. This will take about 6 mins"
53+
$PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts
54+
}
55+
56+
run_and_verify() {
57+
NOW=$(date +"%H:%M:%S")
58+
echo "Starting to run llava runner at ${NOW}"
59+
if [[ ! -f "llava.pte" ]]; then
60+
echo "Export failed. Abort"
61+
exit 1
62+
fi
63+
if [[ ! -f "image.pt" ]]; then
64+
echo "image.pt is missing."
65+
exit 1
66+
fi
67+
if [[ ! -f "tokenizer.bin" ]]; then
68+
echo "tokenizer.bin is missing."
69+
exit 1
70+
fi
71+
RUNTIME_ARGS="--model_path=llava.pte \
72+
--tokenizer_path=tokenizer.bin \
73+
--image_path=image.pt \
74+
--prompt=ASSISTANT: \
75+
--temperature=0 \
76+
--seq_len=650"
77+
cmake-out/examples/models/llava/llava_main ${RUNTIME_ARGS} > result.txt
78+
# verify result.txt
79+
RESULT=$(cat result.txt)
80+
# set the expected prefix to be the same as prompt because there's a bug in sdpa_with_kv_cache that causes <unk> tokens.
81+
EXPECTED_PREFIX="ASSISTANT:"
82+
if [[ "${RESULT}" == *"${EXPECTED_PREFIX}"* ]]; then
83+
echo "Expected result prefix: ${EXPECTED_PREFIX}"
84+
echo "Actual result: ${RESULT}"
85+
echo "Success"
86+
exit 0
87+
else
88+
echo "Expected result prefix: ${EXPECTED_PREFIX}"
89+
echo "Actual result: ${RESULT}"
90+
echo "Failure; results not the same"
91+
exit 1
92+
fi
93+
}
94+
95+
cmake_install_executorch_libraries
96+
cmake_build_llava_runner
97+
export_llava
98+
run_and_verify

.github/workflows/pull.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ jobs:
187187
# Test selective build
188188
PYTHON_EXECUTABLE=python bash examples/selective_build/test_selective_build.sh "${BUILD_TOOL}"
189189
190-
test-export-llava-linux:
190+
test-llava-runner-linux:
191191
name: test-export-llava-linux
192192
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
193193
strategy:
@@ -215,6 +215,9 @@ jobs:
215215
# run python unittest
216216
python -m unittest examples.models.llava.test.test_llava
217217
218+
# run e2e (export, tokenizer and runner)
219+
PYTHON_EXECUTABLE=python bash .ci/scripts/test_llava.sh
220+
218221
test-quantized-aot-lib-linux:
219222
name: test-quantized-aot-lib-linux
220223
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@ if(EXECUTORCH_ENABLE_EVENT_TRACER)
130130
add_definitions(-DET_EVENT_TRACER_ENABLED)
131131
endif()
132132

133+
option(EXECUTORCH_DO_NOT_USE_CXX11_ABI "Define _GLIBCXX_USE_CXX11_ABI=0 if ON"
134+
OFF
135+
)
136+
if(EXECUTORCH_DO_NOT_USE_CXX11_ABI)
137+
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
138+
endif()
133139
# -ffunction-sections -fdata-sections: breaks function and data into sections so
134140
# they can be properly gc'd. -s: strip symbol. -fno-exceptions -fno-rtti:
135141
# disables exceptions and runtime type.

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,23 @@ def test_vit_skip_conv(self):
6868
)
6969
)
7070

71+
conv_block = ["aten.convolution.default", "executorch_call_delegate"]
72+
safe_softmax_block = [
73+
"getitem",
74+
"getitem",
75+
"getitem",
76+
"getitem",
77+
"aten.any.dim",
78+
"executorch_call_delegate",
79+
]
80+
final_block = ["getitem"]
81+
total = conv_block + 12 * safe_softmax_block + final_block
82+
7183
assert [
7284
node.target.__name__
7385
for node in delegated_program_manager.exported_program().graph.nodes
7486
if node.op == "call_function"
75-
] == [
76-
"aten.convolution.default",
77-
"executorch_call_delegate",
78-
"getitem",
79-
]
87+
] == total
8088

8189

8290
if __name__ == "__main__":

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
4343
exir_ops.edge.aten.hardtanh.default,
4444
exir_ops.edge.aten.convolution.default,
4545
exir_ops.edge.aten.div.Tensor,
46+
exir_ops.edge.aten.split_with_sizes_copy.default,
4647
exir_ops.edge.aten.full.default,
4748
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
4849
exir_ops.edge.aten.avg_pool2d.default,

backends/arm/operators/op_slice.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def define_node(
4040
shape = input_node.shape
4141
dim = dim.number
4242
end = (shape[dim] + end.number) % shape[dim]
43+
if end == 0:
44+
end = shape[dim]
4345
size = end - start.number
4446
assert size > 0
4547
assert size <= shape[dim]

backends/arm/passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from executorch.backends.arm.passes.convert_expand_copy_to_repeat import (
1313
ConvertExpandCopyToRepeatPass,
1414
)
15+
from executorch.backends.arm.passes.convert_split_to_slice import (
16+
ConvertSplitToSlicePass,
17+
)
1518
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
1619
from executorch.exir.backend.compile_spec_schema import CompileSpec
1720
from executorch.exir.pass_manager import PassManager
@@ -28,6 +31,7 @@ def transform_to_backend_pipeline(
2831
"""Apply passes before transforming program to backend"""
2932
self.add_pass(RemoveClonePass())
3033
self.add_pass(ConvertExpandCopyToRepeatPass())
34+
self.add_pass(ConvertSplitToSlicePass())
3135
for spec in compile_spec:
3236
if spec.key == "permute_memory_format":
3337
memory_format = spec.value.decode()
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch.fx
8+
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
13+
class ConvertSplitToSlicePass(ExportPass):
14+
"""
15+
Replace a split operation with many slice operations.
16+
"""
17+
18+
split_ops = (
19+
exir_ops.edge.aten.split_with_sizes_copy.default,
20+
exir_ops.edge.aten.split_copy.Tensor,
21+
)
22+
slice = exir_ops.edge.aten.slice_copy.Tensor
23+
24+
def call(self, graph_module: torch.fx.GraphModule):
25+
graph = graph_module.graph
26+
for node in graph.nodes:
27+
if node.target not in self.split_ops:
28+
continue
29+
30+
# Get useful variables
31+
split_node = node
32+
input_node = split_node.all_input_nodes[0]
33+
output_nodes = split_node.users.copy()
34+
_, shape, _ = extract_tensor_meta(input_node.meta)
35+
rank = len(shape)
36+
split_lengths = split_node.args[1]
37+
dim = split_node.args[2] if len(split_node.args) > 2 else 0
38+
dim = (dim + rank) % rank
39+
40+
assert (
41+
sum(split_lengths) == shape[dim]
42+
), "Given split lengths don't sum up to the size of the dimension."
43+
44+
# Convert split argument 'split_lengths' to slice arguments start and end.
45+
starts = [0] * len(split_lengths)
46+
ends = [0] * len(split_lengths)
47+
start = 0
48+
end = 0
49+
for i, split_length in enumerate(split_lengths):
50+
end = start + split_length
51+
starts[i] = start
52+
ends[i] = end
53+
start = end
54+
55+
# Output nodes are of type getitem
56+
# Create one slice node for each output node with matching arguments.
57+
with graph_module.graph.inserting_before(split_node):
58+
for output_node in output_nodes:
59+
index = output_node.args[1]
60+
slice_node = graph.create_node(
61+
"call_function",
62+
self.slice,
63+
(input_node, dim, starts[index], ends[index]),
64+
)
65+
slice_node.meta = split_node.meta.copy()
66+
slice_node.meta["val"] = slice_node.meta["val"][index]
67+
output_node.replace_input_with(split_node, slice_node)
68+
graph.eliminate_dead_code()
69+
graph_module.recompile()
70+
return PassResult(graph_module, True)

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Utility functions for ArmQuantizer
1010
#
1111

12+
import operator
1213
from typing import Callable, cast, List
1314

1415
import torch
@@ -141,8 +142,11 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
141142
torch.ops.aten.view_copy.default,
142143
torch.ops.aten.view.default,
143144
torch.ops.aten.slice.Tensor,
145+
torch.ops.aten.split.Tensor,
146+
torch.ops.aten.split_with_sizes.default,
144147
torch.ops.aten.flatten.using_ints,
145148
torch.ops.aten.dropout.default,
149+
operator.getitem,
146150
]
147151

148152

backends/arm/test/ops/test_slice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def forward(self, x: torch.Tensor):
3333
elif x.dim() == 3:
3434
return x[0:7, 0:1, 0:8]
3535
elif x.dim() == 4:
36-
return x[:, 2:5, 3:5, 4:5]
36+
return x[:, 2:5, 3:5, 4:10]
3737

3838
def _test_slice_tosa_MI_pipeline(
3939
self, module: torch.nn.Module, test_data: torch.Tensor

0 commit comments

Comments
 (0)