Skip to content

Commit e1e6535

Browse files
committed
Update base for Update on "Transform model to be able to use Attention Sink"
This PR adds necessary functions for transforming the model to be able to use Attention Sink. Differential Revision: [D65571289](https://our.internmc.facebook.com/intern/diff/D65571289/) [ghstack-poisoned]
2 parents f7ec0af + ddec0c7 commit e1e6535

File tree

96 files changed

+5202
-714
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+5202
-714
lines changed

.ci/scripts/test_llama.sh

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ UPLOAD_DIR="${UPLOAD_DIR:-}"
5151
# Default PT2E_QUANTIZE to empty string if not set
5252
PT2E_QUANTIZE="${PT2E_QUANTIZE:-}"
5353

54+
# Default CMake Build Type to release mode
55+
CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release}
56+
5457
if [[ $# -lt 4 ]]; then # Assuming 4 mandatory args
5558
echo "Expecting at least 4 positional arguments"
5659
echo "Usage: [...]"
@@ -143,7 +146,7 @@ cmake_install_executorch_libraries() {
143146
rm -rf cmake-out
144147
retry cmake \
145148
-DCMAKE_INSTALL_PREFIX=cmake-out \
146-
-DCMAKE_BUILD_TYPE=Debug \
149+
-DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
147150
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
148151
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
149152
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
@@ -157,22 +160,22 @@ cmake_install_executorch_libraries() {
157160
-DQNN_SDK_ROOT="$QNN_SDK_ROOT" \
158161
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
159162
-Bcmake-out .
160-
cmake --build cmake-out -j9 --target install --config Debug
163+
cmake --build cmake-out -j9 --target install --config "$CMAKE_BUILD_TYPE"
161164
}
162165

163166
cmake_build_llama_runner() {
164167
echo "Building llama runner"
165168
dir="examples/models/llama"
166169
retry cmake \
167170
-DCMAKE_INSTALL_PREFIX=cmake-out \
168-
-DCMAKE_BUILD_TYPE=Debug \
171+
-DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
169172
-DEXECUTORCH_BUILD_KERNELS_CUSTOM="$CUSTOM" \
170173
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
171174
-DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \
172175
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
173176
-Bcmake-out/${dir} \
174177
${dir}
175-
cmake --build cmake-out/${dir} -j9 --config Debug
178+
cmake --build cmake-out/${dir} -j9 --config "$CMAKE_BUILD_TYPE"
176179

177180
}
178181

.ci/scripts/test_llava.sh

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
set -exu
99
# shellcheck source=/dev/null
1010

11-
BUILD_TYPE=${1:-Debug}
1211
TARGET_OS=${2:-Native}
1312
BUILD_DIR=${3:-cmake-out}
13+
CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release}
1414

15-
echo "Building with BUILD_TYPE: $BUILD_TYPE, TARGET_OS: $TARGET_OS, BUILD_DIR: $BUILD_DIR"
15+
echo "Building with CMAKE_BUILD_TYPE: $CMAKE_BUILD_TYPE, TARGET_OS: $TARGET_OS, BUILD_DIR: $BUILD_DIR"
1616

1717
if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
1818
PYTHON_EXECUTABLE=python3
@@ -32,7 +32,7 @@ if hash nproc &> /dev/null; then NPROC=$(nproc); fi
3232

3333
EXECUTORCH_COMMON_CMAKE_ARGS=" \
3434
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
35-
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
35+
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
3636
-DEXECUTORCH_ENABLE_LOGGING=ON \
3737
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
3838
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
@@ -49,7 +49,7 @@ cmake_install_executorch_libraries() {
4949
${EXECUTORCH_COMMON_CMAKE_ARGS} \
5050
-B${BUILD_DIR} .
5151

52-
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${BUILD_TYPE}
52+
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
5353
}
5454

5555
cmake_install_executorch_libraries_for_android() {
@@ -59,14 +59,14 @@ cmake_install_executorch_libraries_for_android() {
5959
${EXECUTORCH_COMMON_CMAKE_ARGS} \
6060
-B${BUILD_DIR} .
6161

62-
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${BUILD_TYPE}
62+
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
6363
}
6464

6565

6666
LLAVA_COMMON_CMAKE_ARGS=" \
6767
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
6868
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
69-
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
69+
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
7070
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
7171
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
7272
-DEXECUTORCH_BUILD_XNNPACK=ON"
@@ -81,7 +81,7 @@ cmake_build_llava_runner() {
8181
-B${BUILD_DIR}/${dir} \
8282
${dir}
8383

84-
cmake --build ${BUILD_DIR}/${dir} -j${NPROC} --config ${BUILD_TYPE}
84+
cmake --build ${BUILD_DIR}/${dir} -j${NPROC} --config ${CMAKE_BUILD_TYPE}
8585
}
8686

8787

@@ -98,7 +98,7 @@ cmake_build_llava_runner_for_android() {
9898
-B${BUILD_DIR}/${dir} \
9999
${dir}
100100

101-
cmake --build ${BUILD_DIR}/${dir} -j${NPROC} --config ${BUILD_TYPE}
101+
cmake --build ${BUILD_DIR}/${dir} -j${NPROC} --config ${CMAKE_BUILD_TYPE}
102102
}
103103

104104
# only export the one without custom op for now since it's

.github/workflows/ghstack_land.yml

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,7 @@ on:
33
pull_request:
44
types: [closed]
55
branches:
6-
- 'gh/cccclai/[0-9]+/base'
7-
- 'gh/dbort/[0-9]+/base'
8-
- 'gh/dvorjackz/[0-9]+/base'
9-
- 'gh/guangy10/[0-9]+/base'
10-
- 'gh/helunwencser/[0-9]+/base'
11-
- 'gh/jorgep31415/[0-9]+/base'
12-
- 'gh/kimishpatel/[0-9]+/base'
13-
- 'gh/kirklandsign/[0-9]+/base'
14-
- 'gh/larryliu0820/[0-9]+/base'
15-
- 'gh/lucylq/[0-9]+/base'
16-
- 'gh/manuelcandales/[0-9]+/base'
17-
- 'gh/mcr229/[0-9]+/base'
18-
- 'gh/swolchok/[0-9]+/base'
19-
- 'gh/SS-JIA/[0-9]+/base'
20-
- 'gh/trivedivivek/[0-9]+/base'
6+
- 'gh/*/[0-9]+/base'
217

228
jobs:
239
ghstack_merge_to_main:

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ jobs:
290290
# ${CONDA_RUN} python -m unittest examples.models.llava.test.test_llava
291291

292292
# # run e2e (export, tokenizer and runner)
293-
# PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_llava.sh Release
293+
# PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_llava.sh
294294

295295
test-qnn-model:
296296
name: test-qnn-model

CMakeLists.txt

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,22 @@ if(EXECUTORCH_BUILD_PTHREADPOOL
682682
endif()
683683

684684
if(EXECUTORCH_BUILD_PYBIND)
685+
# Setup RPATH.
686+
# See https://gitlab.kitware.com/cmake/community/-/wikis/doc/cmake/RPATH-handling
687+
if(APPLE)
688+
set(CMAKE_MACOSX_RPATH ON)
689+
set(_rpath_portable_origin "@loader_path")
690+
else()
691+
set(_rpath_portable_origin $ORIGIN)
692+
endif(APPLE)
693+
# Use separate rpaths during build and install phases
694+
set(CMAKE_SKIP_BUILD_RPATH FALSE)
695+
# Don't use the install-rpath during the build phase
696+
set(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE)
697+
set(CMAKE_INSTALL_RPATH "${_rpath_portable_origin}")
698+
# Automatically add all linked folders that are NOT in the build directory to
699+
# the rpath (per library?)
700+
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
685701
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pybind11)
686702

687703
if(NOT EXECUTORCH_BUILD_EXTENSION_DATA_LOADER)
@@ -765,46 +781,6 @@ if(EXECUTORCH_BUILD_PYBIND)
765781
target_include_directories(portable_lib PRIVATE ${TORCH_INCLUDE_DIRS})
766782
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
767783
target_link_libraries(portable_lib PRIVATE ${_dep_libs})
768-
if(APPLE)
769-
# pip wheels will need to be able to find the torch libraries. On Linux, the
770-
# .so has non-absolute dependencies on libs like "libtorch.so" without
771-
# paths; as long as we `import torch` first, those dependencies will work.
772-
# But Apple dylibs do not support non-absolute dependencies, so we need to
773-
# tell the loader where to look for its libraries. The LC_LOAD_DYLIB entries
774-
# for the torch libraries will look like "@rpath/libtorch.dylib", so we can
775-
# add an LC_RPATH entry to look in a directory relative to the installed
776-
# location of our _portable_lib.so file. To see these LC_* values, run
777-
# `otool -l _portable_lib*.so`.
778-
set_target_properties(
779-
portable_lib
780-
PROPERTIES # Assume that this library will be installed in
781-
# `site-packages/executorch/extension/pybindings`, and that
782-
# the torch libs are in `site-packages/torch/lib`.
783-
BUILD_RPATH "@loader_path/../../../torch/lib"
784-
INSTALL_RPATH "@loader_path/../../../torch/lib"
785-
# Assume <executorch> is the root `site-packages/executorch`
786-
# Need to add <executorch>/extension/llm/custom_ops for
787-
# libcustom_ops_aot_lib.dylib
788-
BUILD_RPATH "@loader_path/../../extension/llm/custom_ops"
789-
INSTALL_RPATH "@loader_path/../../extension/llm/custom_ops"
790-
# Need to add <executorch>/kernels/quantized for
791-
# libquantized_ops_aot_lib.dylib
792-
BUILD_RPATH "@loader_path/../../kernels/quantized"
793-
INSTALL_RPATH "@loader_path/../../kernels/quantized"
794-
)
795-
else()
796-
set_target_properties(
797-
portable_lib
798-
PROPERTIES
799-
# Assume <executorch> is the root `site-packages/executorch`
800-
# Need to add <executorch>/extension/llm/custom_ops for
801-
# libcustom_ops_aot_lib
802-
# Need to add <executorch>/kernels/quantized for
803-
# libquantized_ops_aot_lib
804-
BUILD_RPATH
805-
"$ORIGIN:$ORIGIN/../../extension/llm/custom_ops:$ORIGIN/../../kernels/quantized"
806-
)
807-
endif()
808784

809785
install(TARGETS portable_lib
810786
LIBRARY DESTINATION executorch/extension/pybindings

backends/arm/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,14 @@ python_library(
110110
"//executorch/backends/arm/operators:node_visitor",
111111
],
112112
)
113+
114+
python_library(
115+
name = "arm_model_evaluator",
116+
src = [
117+
"util/arm_model_evaluator.py",
118+
],
119+
typing = True,
120+
deps = [
121+
"//caffe2:torch",
122+
]
123+
)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
DecomposeSoftmaxesPass,
3030
)
3131
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
32-
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
33-
InsertSqueezeAfterSumPass,
32+
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
33+
KeepDimsFalseToSqueezePass,
3434
)
3535
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
3636
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
@@ -71,7 +71,7 @@ def transform_to_backend_pipeline(
7171
self.add_pass(DecomposeMeanDimPass())
7272
self.add_pass(MatchArgRanksPass(exported_program))
7373
self.add_pass(DecomposeDivPass())
74-
self.add_pass(InsertSqueezeAfterSumPass())
74+
self.add_pass(KeepDimsFalseToSqueezePass())
7575
self.add_pass(ConvertSplitToSlicePass())
7676
self.add_pass(Conv1dUnsqueezePass(exported_program))
7777
self.add_pass(DecomposeSoftmaxesPass())

backends/arm/_passes/arm_pass_utils.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# pyre-unsafe
99

10+
from inspect import isclass
1011
from typing import Optional
1112

1213
import torch
@@ -133,3 +134,60 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
133134
fake_tensor, FakeTensor
134135
), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
135136
return fake_tensor
137+
138+
139+
def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
140+
"""
141+
Help-function for getting a value from node.args/ kwargs, three cases:
142+
1. By position in node.args - Returns arg at given position or default_value if index is one out of bounds
143+
2. By key in node.kwargs - Returns kwarg with given key or default_value if it does not exist
144+
3. By type in node.args - Returns first arg of args of given type. Useful for cases where arg positions may differ but types are unique.
145+
"""
146+
if isinstance(key, int):
147+
if 0 <= key < len(args):
148+
return args[key]
149+
elif key == len(args):
150+
if default_value is not None:
151+
return default_value
152+
else:
153+
raise RuntimeError(f"No default value given for index {key}")
154+
else:
155+
raise RuntimeError(
156+
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
157+
)
158+
elif isinstance(key, str):
159+
return args.get(key, default_value)
160+
elif isclass(key):
161+
for arg in args:
162+
if isinstance(arg, key):
163+
return arg
164+
if default_value is not None:
165+
return default_value
166+
else:
167+
raise RuntimeError(f"No arg of type {key}")
168+
else:
169+
raise RuntimeError("Invalid type")
170+
171+
172+
def set_node_arg(node: torch.fx.Node, i: int | str, value):
173+
"""
174+
Help-function for setting a value in node.args/ kwargs. If the index is one larger than the list size, the value is instead appended to the list.
175+
"""
176+
if isinstance(i, int):
177+
if 0 <= i < len(node.args):
178+
args = list(node.args)
179+
args[i] = value
180+
node.args = tuple(args)
181+
return
182+
elif i == len(node.args):
183+
node.args = node.args + (value,)
184+
else:
185+
raise RuntimeError(
186+
f"Out of bounds index {i} for setting value in {node} args (of size {len(node.args)})"
187+
)
188+
elif isinstance(i, str):
189+
kwargs = dict(node.kwargs)
190+
kwargs[i] = value
191+
node.kwargs = kwargs
192+
else:
193+
raise RuntimeError("Invalid type")

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-unsafe
88

99
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
1011
from executorch.exir.dialects._ops import ops as exir_ops
1112
from executorch.exir.pass_base import ExportPass
1213

@@ -42,16 +43,16 @@ def call_operator(self, op, args, kwargs, meta):
4243
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
4344
return super().call_operator(op, args, kwargs, meta)
4445

45-
x = args[0]
46-
dim = args[1]
47-
keepdim = args[2] if len(args) > 2 else False
48-
if not keepdim:
49-
return super().call_operator(op, args, kwargs, meta)
50-
# if keepdim == True and dim == [-1, -2], mean.dim can be
46+
x = get_node_arg(args, 0)
47+
dim = get_node_arg(args, 1)
48+
keepdim = get_node_arg(args, 2, False)
49+
50+
# if dim == [-1, -2], mean.dim can be
5151
# decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
5252
if dim == [-1, -2]:
5353
# Simply return the mean.dim operator for future decomposition.
5454
return super().call_operator(op, args, kwargs, meta)
55+
5556
shape = meta["val"].size()
5657
dtype = meta["val"].dtype
5758
input_shape = x.data.size()

backends/arm/_passes/decompose_var_pass.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
import torch
11+
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
1112
from executorch.exir.dialects._ops import ops as exir_ops
1213
from executorch.exir.pass_base import ExportPass
1314

@@ -53,26 +54,30 @@ def call_operator(self, op, args, kwargs, meta):
5354
torch.ops.aten.var.dim,
5455
):
5556
return super().call_operator(op, args, kwargs, meta)
56-
shape = meta["val"].size()
57+
58+
x = args[0]
59+
input_shape = x.data.size()
60+
shape = list(meta["val"].size())
61+
if shape == []:
62+
shape = [1 for _ in input_shape]
63+
5764
dtype = meta["val"].dtype
58-
dim = args[1] if len(args) > 1 else list(range(len(shape)))
65+
# Get dim from args based on argument type
66+
dim = get_node_arg(args, key=list, default_value=list(range(len(shape))))
67+
5968
if op == torch.ops.aten.var.dim:
60-
correction = args[-2]
61-
keepdim = args[-1]
69+
keepdim = get_node_arg(args, bool, False)
70+
correction = get_node_arg(args, int, 1)
6271
else:
63-
correction = kwargs["correction"]
64-
keepdim = kwargs.get("keepdim", False)
65-
if not keepdim:
66-
return super().call_operator(op, args, kwargs, meta)
72+
correction = get_node_arg(kwargs, "correction", 1)
73+
keepdim = get_node_arg(kwargs, "keepdim", False)
6774

68-
x = args[0]
69-
input_shape = x.data.size()
7075
N = 1
7176
for d in dim:
7277
N *= input_shape[d]
7378

7479
mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
75-
mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
80+
mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
7681
diff = super().call_operator(diff_op, (x, mean), {}, meta)
7782
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
7883
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)

0 commit comments

Comments
 (0)