Skip to content

Commit 5fb6cdb

Browse files
committed
[llm] Use new API to register custom ops for llama model
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 399482c commit 5fb6cdb

File tree

18 files changed

+283
-361
lines changed

18 files changed

+283
-361
lines changed

.ci/scripts/test_llama.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ echo "Exporting ${EXPORTED_MODEL_NAME}"
119119
EXPORT_ARGS="-c stories110M.pt -p ${PARAMS} -d ${DTYPE} -n ${EXPORTED_MODEL_NAME}"
120120
if [[ "${MODE}" == "xnnpack" ]]; then
121121
EXPORT_ARGS="${EXPORT_ARGS} --pt2e_quantize xnnpack_dynamic"
122+
elif [[ "${MODE}" == "xnnpack+kv+custom" ]]; then
123+
EXPORT_ARGS="${EXPORT_ARGS} --pt2e_quantize xnnpack_dynamic --use_sdpa_with_kv_cache -kv"
122124
fi
123125
$PYTHON_EXECUTABLE -m examples.models.llama2.export_llama ${EXPORT_ARGS}
124126

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ jobs:
9090
matrix:
9191
dtype: [fp32]
9292
build-tool: [buck2, cmake]
93-
mode: [portable, xnnpack]
93+
mode: [portable, xnnpack, xnnpack+kv+custom]
9494
fail-fast: false
9595
with:
9696
runner: linux.2xlarge

CMakeLists.txt

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,9 @@ option(EXECUTORCH_BUILD_VULKAN "Build the Vulkan backend" OFF)
175175
#
176176
# pthreadpool: build pthreadpool library. Disable on unsupported platforms
177177
#
178-
cmake_dependent_option(EXECUTORCH_BUILD_PTHREADPOOL "Build pthreadpool library."
179-
ON "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF)
178+
cmake_dependent_option(
179+
EXECUTORCH_BUILD_PTHREADPOOL "Build pthreadpool library." ON
180+
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF)
180181

181182
#
182183
# cpuinfo: build cpuinfo library. Disable on unsupported platforms
@@ -499,25 +500,38 @@ if(EXECUTORCH_BUILD_PYBIND)
499500
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sdk)
500501
endif()
501502

503+
# find pytorch lib, to allow pybind to take at::Tensor as input/output
504+
find_package(Torch CONFIG REQUIRED)
505+
find_library(TORCH_PYTHON_LIBRARY torch_python
506+
PATHS "${TORCH_INSTALL_PREFIX}/lib")
507+
508+
set(_dep_libs
509+
${TORCH_PYTHON_LIBRARY}
510+
bundled_program
511+
etdump
512+
executorch
513+
extension_data_loader
514+
portable_ops_lib
515+
util
516+
torch)
517+
502518
if(EXECUTORCH_BUILD_COREML)
503-
set(PYBIND_LINK_COREML "coremldelegate")
519+
list(APPEND _dep_libs coremldelegate)
504520
endif()
505521

506522
if(EXECUTORCH_BUILD_MPS)
507-
set(PYBIND_LINK_MPS "mpsdelegate")
523+
list(APPEND _dep_libs mpsdelegate)
508524
endif()
509525

510526
if(EXECUTORCH_BUILD_XNNPACK)
511-
# need to explicitly specify XNNPACK here
512-
# otherwise uses XNNPACK symbols from libtorch_cpu
513-
set(PYBIND_LINK_XNNPACK xnnpack_backend XNNPACK)
527+
# need to explicitly specify XNNPACK here otherwise uses XNNPACK symbols
528+
# from libtorch_cpu
529+
list(APPEND _dep_libs xnnpack_backend XNNPACK)
514530
endif()
515531

516-
# find pytorch lib, to allow pybind to take at::Tensor as input/output
517-
find_package(Torch CONFIG REQUIRED)
518-
find_library(TORCH_PYTHON_LIBRARY torch_python
519-
PATHS "${TORCH_INSTALL_PREFIX}/lib")
520-
532+
if(EXECUTORCH_BUILD_CUSTOM)
533+
list(APPEND _dep_libs custom_ops custom_ops_aot_lib)
534+
endif()
521535
# compile options for pybind
522536

523537
set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti
@@ -539,19 +553,7 @@ if(EXECUTORCH_BUILD_PYBIND)
539553
PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=portable_lib)
540554
target_include_directories(portable_lib PRIVATE ${TORCH_INCLUDE_DIRS})
541555
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
542-
target_link_libraries(
543-
portable_lib
544-
PUBLIC ${TORCH_PYTHON_LIBRARY}
545-
bundled_program
546-
etdump
547-
executorch
548-
extension_data_loader
549-
portable_ops_lib
550-
util
551-
torch
552-
${PYBIND_LINK_COREML}
553-
${PYBIND_LINK_MPS}
554-
${PYBIND_LINK_XNNPACK})
556+
target_link_libraries(portable_lib PUBLIC ${_dep_libs})
555557

556558
install(TARGETS portable_lib
557559
LIBRARY DESTINATION executorch/extension/pybindings)

backends/xnnpack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ target_include_directories(
7272
xnnpack_schema INTERFACE ${_xnnpack_schema__include_dir}
7373
${EXECUTORCH_ROOT}/third-party/flatbuffers/include)
7474

75+
target_compile_options(pthreadpool PUBLIC ${_common_compile_options})
7576
set(xnnpack_third_party pthreadpool cpuinfo)
7677

7778
include(cmake/Dependencies.cmake)

examples/models/llama2/CMakeLists.txt

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,24 @@ set(_common_compile_options -Wno-deprecated-declarations -fPIC)
4949
# Let files say "include <executorch/path/to/header.h>".
5050
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
5151

52-
# For some reason android build is not able to find where gflags is
53-
# and hence cannot find corresponding .cmake file
52+
# For some reason android build is not able to find where gflags is and hence
53+
# cannot find corresponding .cmake file
5454
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
5555
find_package(gflags REQUIRED)
5656

5757
#
5858
# llama_main: test binary to run llama, with tokenizer and sampler integrated
5959
#
60-
add_executable(llama_main main.cpp
61-
${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/threadpool/cpuinfo_utils.cpp)
60+
add_executable(
61+
llama_main
62+
main.cpp
63+
${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/threadpool/cpuinfo_utils.cpp
64+
)
6265
if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
6366
target_link_options(llama_main PRIVATE "LINKER:--gc-sections")
6467
endif()
6568

66-
# find `executorch` libraries
67-
# Same as for gflags
69+
# find `executorch` libraries Same as for gflags
6870
set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../lib/cmake/ExecuTorch)
6971
find_package(executorch CONFIG REQUIRED)
7072
if(CMAKE_TOOLCHAIN_IOS OR ANDROID)
@@ -77,23 +79,35 @@ add_subdirectory(custom_ops)
7779
# llama_runner library
7880
add_subdirectory(runner)
7981

80-
target_include_directories(llama_main PUBLIC
81-
${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/cpuinfo/include)
82-
target_include_directories(llama_main PUBLIC
83-
${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/pthreadpool/include)
82+
target_include_directories(
83+
llama_main
84+
PUBLIC
85+
${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/cpuinfo/include
86+
)
87+
target_include_directories(
88+
llama_main
89+
PUBLIC
90+
${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/pthreadpool/include
91+
)
8492

8593
set(link_libraries)
8694

8795
if(EXECUTORCH_BUILD_OPTIMIZED)
88-
list(APPEND link_libraries optimized_native_cpu_ops_lib optimized_kernels
89-
portable_kernels cpublas eigen_blas)
96+
list(
97+
APPEND
98+
link_libraries
99+
optimized_native_cpu_ops_lib
100+
optimized_kernels
101+
portable_kernels
102+
cpublas
103+
eigen_blas)
90104
target_link_options_shared_lib(optimized_native_cpu_ops_lib)
91105
else()
92106
list(APPEND link_libraries portable_ops_lib portable_kernels)
93107
target_link_options_shared_lib(portable_ops_lib)
94108
endif()
95109

96-
target_link_libraries(llama_main PUBLIC gflags llama_runner custom_ops_lib)
110+
target_link_libraries(llama_main PUBLIC gflags llama_runner custom_ops)
97111

98112
# XNNPACK pthreadpool cpuinfo
99113
if(TARGET xnnpack_backend)
@@ -114,14 +128,13 @@ if(TARGET qnn_executorch_backend)
114128
target_link_options_shared_lib(qnn_executorch_backend)
115129
endif()
116130

117-
# This one is needed for cpuinfo where it uses android
118-
# specific log lib
131+
# This one is needed for cpuinfo where it uses android specific log lib
119132
if(ANDROID)
120133
list(APPEND link_libraries log)
121134
endif()
122135

123136
target_compile_options(llama_main PUBLIC ${_common_compile_options}
124-
-DET_USE_THREADPOOL)
137+
-DET_USE_THREADPOOL)
125138
target_link_libraries(llama_main PUBLIC ${link_libraries})
126139

127140
if(APPLE)

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ runtime.python_binary(
5252
main_module = "executorch.examples.models.llama2.export_llama",
5353
# visibility = ["//executorch/examples/..."],
5454
preload_deps = [
55+
"//executorch/examples/models/llama2/custom_ops:custom_ops_aot_lib",
5556
"//executorch/kernels/quantized:aot_lib",
5657
],
5758
deps = [

examples/models/llama2/custom_ops/CMakeLists.txt

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,12 @@ include(${EXECUTORCH_SRCS_FILE})
4444
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
4545

4646
# Custom op libraries
47-
set(custom_ops_libs extension_module)
47+
set(custom_ops_libs extension_module executorch)
4848
list(APPEND custom_ops_libs pthreadpool)
4949
list(APPEND custom_ops_libs cpuinfo)
5050
list(APPEND custom_ops_libs cpublas)
5151
list(APPEND custom_ops_libs eigen_blas)
5252

53-
# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
54-
# Executorch (for runtime). Here select all ops in optimized.yaml
55-
set(_yaml "${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml")
56-
gen_selected_ops("${_yaml}" "" "")
57-
58-
generate_bindings_for_kernels(FUNCTIONS_YAML
59-
${CMAKE_CURRENT_SOURCE_DIR}/custom_ops.yaml)
60-
message("Generated files ${gen_command_sources}")
61-
6253
list(TRANSFORM _custom_ops__srcs PREPEND "${EXECUTORCH_ROOT}/")
6354

6455
# TODO: Consider moving xnnpack/threadpool in a separate lib since it's now used
@@ -82,7 +73,14 @@ target_link_libraries(custom_ops PUBLIC ${custom_ops_libs})
8273
target_compile_options(custom_ops PUBLIC ${_common_compile_options}
8374
-DET_USE_THREADPOOL)
8475

85-
# Build a library for _custom_ops_srcs
86-
#
87-
# custom_ops_lib: Register optimized ops kernels into Executorch runtime
88-
gen_operators_lib("custom_ops_lib" KERNEL_LIBS custom_ops DEPS executorch)
76+
# Add a AOT library
77+
find_package(Torch CONFIG REQUIRED)
78+
add_library(custom_ops_aot_lib SHARED
79+
${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp)
80+
target_include_directories(custom_ops_aot_lib
81+
PUBLIC "${_common_include_directories}")
82+
target_include_directories(
83+
custom_ops_aot_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include")
84+
target_link_libraries(custom_ops_aot_lib PUBLIC custom_ops torch)
85+
86+
install(TARGETS custom_ops custom_ops_aot_lib DESTINATION lib)

examples/models/llama2/custom_ops/custom_ops.yaml

Lines changed: 0 additions & 14 deletions
This file was deleted.

examples/models/llama2/custom_ops/op_sdpa.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
77
*/
8-
9-
#include <executorch/runtime/kernel/kernel_includes.h>
8+
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
109

1110
#include <executorch/kernels/optimized/blas/CPUBlas.h>
1211
#include <executorch/kernels/optimized/vec/functional.h>
@@ -22,6 +21,7 @@
2221
#include <executorch/backends/xnnpack/threadpool/threadpool.h>
2322
#include <executorch/extension/parallel/thread_parallel.h>
2423
#endif
24+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
2525

2626
namespace torch {
2727
namespace executor {
@@ -843,3 +843,8 @@ Tensor& sdpa_with_kv_cache_out(
843843
} // namespace native
844844
} // namespace executor
845845
} // namespace torch
846+
847+
EXECUTORCH_LIBRARY(
848+
llama,
849+
"sdpa_with_kv_cache.out",
850+
torch::executor::native::sdpa_with_kv_cache_out);
examples/models/llama2/custom_ops/op_sdpa.h

Lines changed: 45 additions & 0 deletions
This file was added.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#include <executorch/runtime/kernel/kernel_includes.h>
9+
10+
namespace torch {
11+
namespace executor {
12+
13+
namespace native {
14+
15+
Tensor& sdpa_with_kv_cache_out(
16+
RuntimeContext& ctx,
17+
const Tensor& q_projected,
18+
const Tensor& k_projected,
19+
const Tensor& v_projected,
20+
Tensor& key_cache,
21+
Tensor& value_cache,
22+
const int64_t start_pos,
23+
const int64_t seq_len,
24+
const optional<Tensor>& attn_mask,
25+
const double dropout_p,
26+
const bool is_causal,
27+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
28+
const optional<double> scale,
29+
Tensor& output);
30+
31+
Tensor& flash_attention_kernel_out(
32+
RuntimeContext& ctx,
33+
const Tensor& query,
34+
const Tensor& key,
35+
const Tensor& value,
36+
const optional<Tensor>& attn_mask,
37+
const double dropout_p,
38+
const bool is_causal,
39+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
40+
const optional<double> scale,
41+
Tensor& output);
42+
43+
} // namespace native
44+
} // namespace executor
45+
} // namespace torch

0 commit comments

Comments (0)