-
Notifications
You must be signed in to change notification settings - Fork 608
Use new API to register custom ExecuTorch kernels into ATen #2937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -144,6 +144,8 @@ option(EXECUTORCH_BUILD_COREML "Build the Core ML backend" OFF) | |
|
||
option(EXECUTORCH_BUILD_CUSTOM "Build the custom kernels" OFF) | ||
|
||
option(EXECUTORCH_BUILD_CUSTOM_OPS_AOT "Build the custom ops lib for AOT" OFF) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This name is too similar to Really, |
||
|
||
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "Build the Data Loader extension" | ||
OFF) | ||
|
||
|
@@ -185,12 +187,19 @@ cmake_dependent_option( | |
cmake_dependent_option(EXECUTORCH_BUILD_CPUINFO "Build cpuinfo library." ON | ||
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF) | ||
|
||
if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT) | ||
set(EXECUTORCH_BUILD_CUSTOM ON) | ||
endif() | ||
|
||
if(EXECUTORCH_BUILD_CUSTOM) | ||
set(EXECUTORCH_BUILD_OPTIMIZED ON) | ||
endif() | ||
|
||
if(EXECUTORCH_BUILD_CPUINFO) | ||
# --- cpuinfo | ||
set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is an internal temp var it should have a name like |
||
${CMAKE_POSITION_INDEPENDENT_CODE}) | ||
set(CMAKE_POSITION_INDEPENDENT_CODE ON) | ||
set(CPUINFO_SOURCE_DIR "backends/xnnpack/third-party/cpuinfo") | ||
set(CPUINFO_BUILD_TOOLS | ||
OFF | ||
|
@@ -212,10 +221,15 @@ if(EXECUTORCH_BUILD_CPUINFO) | |
CACHE STRING "") | ||
set(CLOG_SOURCE_DIR "${CPUINFO_SOURCE_DIR}/deps/clog") | ||
add_subdirectory("${CPUINFO_SOURCE_DIR}") | ||
set(CMAKE_POSITION_INDEPENDENT_CODE | ||
${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG}) | ||
endif() | ||
|
||
if(EXECUTORCH_BUILD_PTHREADPOOL) | ||
# --- pthreadpool | ||
set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG | ||
${CMAKE_POSITION_INDEPENDENT_CODE}) | ||
set(CMAKE_POSITION_INDEPENDENT_CODE ON) | ||
set(PTHREADPOOL_SOURCE_DIR "backends/xnnpack/third-party/pthreadpool") | ||
set(PTHREADPOOL_BUILD_TESTS | ||
OFF | ||
|
@@ -235,6 +249,8 @@ if(EXECUTORCH_BUILD_PTHREADPOOL) | |
CACHE STRING "") | ||
endif() | ||
add_subdirectory("${PTHREADPOOL_SOURCE_DIR}") | ||
set(CMAKE_POSITION_INDEPENDENT_CODE | ||
${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG}) | ||
endif() | ||
|
||
if(NOT PYTHON_EXECUTABLE) | ||
|
@@ -546,6 +562,9 @@ if(EXECUTORCH_BUILD_PYBIND) | |
list(APPEND _dep_libs custom_ops) | ||
endif() | ||
|
||
if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT) | ||
list(APPEND _dep_libs custom_ops_aot_lib) | ||
endif() | ||
# compile options for pybind | ||
|
||
set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,7 @@ if(NOT TORCH_ROOT) | |
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch) | ||
endif() | ||
|
||
set(_common_compile_options -Wno-deprecated-declarations) | ||
set(_common_compile_options -Wno-deprecated-declarations -fPIC) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a comment explaining why this sets There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I try to avoid using |
||
|
||
include(${EXECUTORCH_ROOT}/build/Utils.cmake) | ||
include(${EXECUTORCH_ROOT}/build/Codegen.cmake) | ||
|
@@ -44,7 +44,7 @@ include(${EXECUTORCH_SRCS_FILE}) | |
set(_common_include_directories ${EXECUTORCH_ROOT}/..) | ||
|
||
# Custom op libraries | ||
set(custom_ops_libs extension_module) | ||
set(custom_ops_libs executorch_no_prim_ops) | ||
list(APPEND custom_ops_libs pthreadpool) | ||
list(APPEND custom_ops_libs cpuinfo) | ||
list(APPEND custom_ops_libs cpublas) | ||
|
@@ -76,3 +76,19 @@ target_compile_options(custom_ops PUBLIC ${_common_compile_options} | |
-DET_USE_THREADPOOL) | ||
|
||
install(TARGETS custom_ops DESTINATION lib) | ||
|
||
if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT) | ||
# Add an AOT library | ||
find_package(Torch CONFIG REQUIRED) | ||
add_library(custom_ops_aot_lib SHARED | ||
${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp) | ||
target_include_directories(custom_ops_aot_lib | ||
PUBLIC "${_common_include_directories}") | ||
target_include_directories( | ||
custom_ops_aot_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include") | ||
target_link_libraries(custom_ops_aot_lib PUBLIC custom_ops torch) | ||
target_compile_options(custom_ops_aot_lib PUBLIC -Wno-deprecated-declarations | ||
-fPIC -frtti -fexceptions) | ||
|
||
install(TARGETS custom_ops_aot_lib DESTINATION lib) | ||
endif() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h> | ||
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h> | ||
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> | ||
|
||
#include <torch/library.h> | ||
|
||
namespace torch { | ||
namespace executor { | ||
|
||
namespace native { | ||
|
||
Tensor& sdpa_with_kv_cache_out_no_context( | ||
const Tensor& q_projected, | ||
const Tensor& k_projected, | ||
const Tensor& v_projected, | ||
Tensor& key_cache, | ||
Tensor& value_cache, | ||
const int64_t start_pos, | ||
const int64_t seq_len, | ||
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue | ||
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy | ||
const optional<Tensor> attn_mask, | ||
const double dropout_p, | ||
const bool is_causal, | ||
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy | ||
const optional<double> scale, | ||
Tensor& output) { | ||
exec_aten::RuntimeContext context{}; | ||
return torch::executor::native::sdpa_with_kv_cache_out( | ||
context, | ||
q_projected, | ||
k_projected, | ||
v_projected, | ||
key_cache, | ||
value_cache, | ||
start_pos, | ||
seq_len, | ||
attn_mask, | ||
dropout_p, | ||
is_causal, | ||
scale, | ||
output); | ||
} | ||
|
||
at::Tensor sdpa_with_kv_cache_aten( | ||
const at::Tensor& q_projected, | ||
const at::Tensor& k_projected, | ||
const at::Tensor& v_projected, | ||
at::Tensor& key_cache, | ||
at::Tensor& value_cache, | ||
const int64_t start_pos, | ||
const int64_t seq_len, | ||
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue | ||
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy | ||
const c10::optional<at::Tensor> attn_mask, | ||
const double dropout_p, | ||
const bool is_causal, | ||
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy | ||
const c10::optional<double> scale) { | ||
auto output = at::empty_like(q_projected); | ||
WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11) | ||
(q_projected, | ||
k_projected, | ||
v_projected, | ||
key_cache, | ||
value_cache, | ||
start_pos, | ||
seq_len, | ||
attn_mask, | ||
dropout_p, | ||
is_causal, | ||
scale, | ||
output); | ||
return output; | ||
} | ||
|
||
} // namespace native | ||
} // namespace executor | ||
} // namespace torch | ||
|
||
TORCH_LIBRARY(llama, m) { | ||
m.def( | ||
"sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, " | ||
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, " | ||
"float dropout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor"); | ||
m.def( | ||
"sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, " | ||
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, " | ||
"float dropout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)"); | ||
} | ||
|
||
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { | ||
m.impl( | ||
"sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten); | ||
m.impl( | ||
"sdpa_with_kv_cache.out", | ||
WRAP_TO_ATEN( | ||
torch::executor::native::sdpa_with_kv_cache_out_no_context, 11)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This name will be too vague over time: are we only going to have one library with custom ops? Or is this llama-specifc? If it is just one global library, how do we decide which kinds of custom ops get to live in it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah i agree. this custom ops were examples/models/llama specific at the time it was added. Though things have changed since.