
Use new API to register custom ExecuTorch kernels into ATen #2937


Closed
wants to merge 1 commit into from
19 changes: 19 additions & 0 deletions CMakeLists.txt
@@ -144,6 +144,8 @@ option(EXECUTORCH_BUILD_COREML "Build the Core ML backend" OFF)

option(EXECUTORCH_BUILD_CUSTOM "Build the custom kernels" OFF)

option(EXECUTORCH_BUILD_CUSTOM_OPS_AOT "Build the custom ops lib for AOT" OFF)
Contributor:

This name will be too vague over time: are we only going to have one library with custom ops, or is this llama-specific? If it is just one global library, how do we decide which kinds of custom ops get to live in it?

Contributor:

Yeah, I agree. These custom ops were examples/models/llama-specific at the time they were added, though things have changed since.

Contributor:

This name is too similar to EXECUTORCH_BUILD_CUSTOM. If you ask a random person what the difference is between EXECUTORCH_BUILD_CUSTOM and EXECUTORCH_BUILD_CUSTOM_OPS_AOT, they probably wouldn't have a good answer.

Really, EXECUTORCH_BUILD_CUSTOM should probably have a more descriptive name. If that's too disruptive, please update the help strings for both options to make it more clear how they differ, and help users understand which one to enable and when.
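
As an illustration of that suggestion, a sketch of what more descriptive help strings could look like, assuming the current split between the two options (the wording here is hypothetical, not part of this PR):

option(EXECUTORCH_BUILD_CUSTOM
       "Build the custom (e.g. llama SDPA) kernels into the ExecuTorch runtime"
       OFF)
option(EXECUTORCH_BUILD_CUSTOM_OPS_AOT
       "Build the custom kernels as a shared library that also registers them into ATen for ahead-of-time export; implies EXECUTORCH_BUILD_CUSTOM"
       OFF)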


option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "Build the Data Loader extension"
OFF)

@@ -185,12 +187,19 @@ cmake_dependent_option(
cmake_dependent_option(EXECUTORCH_BUILD_CPUINFO "Build cpuinfo library." ON
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF)

if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
set(EXECUTORCH_BUILD_CUSTOM ON)
endif()

if(EXECUTORCH_BUILD_CUSTOM)
set(EXECUTORCH_BUILD_OPTIMIZED ON)
endif()

if(EXECUTORCH_BUILD_CPUINFO)
# --- cpuinfo
set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
Contributor:

Since this is an internal temp var it should have a name like _original_cmake_position_independent_code_flag. Similar below.
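
A minimal sketch of the suggested pattern, using a lowercase, underscore-prefixed name for the temporary (the variable name is the reviewer's suggestion, not what the PR currently uses):

set(_original_cmake_position_independent_code_flag
    ${CMAKE_POSITION_INDEPENDENT_CODE})
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
add_subdirectory("${CPUINFO_SOURCE_DIR}")
# Restore the caller's value so the global flag does not leak past this block.
set(CMAKE_POSITION_INDEPENDENT_CODE
    ${_original_cmake_position_independent_code_flag})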

${CMAKE_POSITION_INDEPENDENT_CODE})
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CPUINFO_SOURCE_DIR "backends/xnnpack/third-party/cpuinfo")
set(CPUINFO_BUILD_TOOLS
OFF
@@ -212,10 +221,15 @@ if(EXECUTORCH_BUILD_CPUINFO)
CACHE STRING "")
set(CLOG_SOURCE_DIR "${CPUINFO_SOURCE_DIR}/deps/clog")
add_subdirectory("${CPUINFO_SOURCE_DIR}")
set(CMAKE_POSITION_INDEPENDENT_CODE
${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
endif()

if(EXECUTORCH_BUILD_PTHREADPOOL)
# --- pthreadpool
set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
${CMAKE_POSITION_INDEPENDENT_CODE})
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(PTHREADPOOL_SOURCE_DIR "backends/xnnpack/third-party/pthreadpool")
set(PTHREADPOOL_BUILD_TESTS
OFF
@@ -235,6 +249,8 @@ if(EXECUTORCH_BUILD_PTHREADPOOL)
CACHE STRING "")
endif()
add_subdirectory("${PTHREADPOOL_SOURCE_DIR}")
set(CMAKE_POSITION_INDEPENDENT_CODE
${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
endif()

if(NOT PYTHON_EXECUTABLE)
@@ -546,6 +562,9 @@ if(EXECUTORCH_BUILD_PYBIND)
list(APPEND _dep_libs custom_ops)
endif()

if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
list(APPEND _dep_libs custom_ops_aot_lib)
endif()
# compile options for pybind

set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti
2 changes: 1 addition & 1 deletion backends/xnnpack/CMakeLists.txt
@@ -81,7 +81,7 @@ add_library(xnnpack_backend STATIC ${_xnnpack_backend__srcs})
target_link_libraries(xnnpack_backend
PRIVATE
${xnnpack_third_party}
executorch
executorch_no_prim_ops
xnnpack_schema)

target_include_directories(xnnpack_backend
3 changes: 2 additions & 1 deletion examples/models/llama2/TARGETS
@@ -18,7 +18,7 @@ runtime.python_library(
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama2/custom_ops:llama_custom_ops_aot_lib",
"//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
],
)

@@ -52,6 +52,7 @@ runtime.python_binary(
main_module = "executorch.examples.models.llama2.export_llama",
# visibility = ["//executorch/examples/..."],
preload_deps = [
"//executorch/examples/models/llama2/custom_ops:custom_ops_aot_lib",
"//executorch/kernels/quantized:aot_lib",
],
deps = [
20 changes: 18 additions & 2 deletions examples/models/llama2/custom_ops/CMakeLists.txt
@@ -25,7 +25,7 @@ if(NOT TORCH_ROOT)
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
endif()

set(_common_compile_options -Wno-deprecated-declarations)
set(_common_compile_options -Wno-deprecated-declarations -fPIC)
Contributor:

Please add a comment explaining why this sets -fPIC manually instead of using CMAKE_POSITION_INDEPENDENT_CODE

Contributor (Author):

I try to avoid using CMAKE_POSITION_INDEPENDENT_CODE as much as possible since it's a global variable. We have to use it for third-party libs because we can't change their CMakeLists.txt, but for this one we can tweak the compile options for this specific lib.
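
For comparison, a target-scoped alternative that avoids both the global variable and a hand-written -fPIC flag — a sketch, assuming the custom_ops target defined below:

# Sets PIC on this one target only; CMake emits the right flag per toolchain.
set_target_properties(custom_ops PROPERTIES POSITION_INDEPENDENT_CODE ON)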


include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
@@ -44,7 +44,7 @@ include(${EXECUTORCH_SRCS_FILE})
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# Custom op libraries
set(custom_ops_libs extension_module)
set(custom_ops_libs executorch_no_prim_ops)
list(APPEND custom_ops_libs pthreadpool)
list(APPEND custom_ops_libs cpuinfo)
list(APPEND custom_ops_libs cpublas)
@@ -76,3 +76,19 @@ target_compile_options(custom_ops PUBLIC ${_common_compile_options}
-DET_USE_THREADPOOL)

install(TARGETS custom_ops DESTINATION lib)

if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
# Add an AOT library
find_package(Torch CONFIG REQUIRED)
add_library(custom_ops_aot_lib SHARED
${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp)
target_include_directories(custom_ops_aot_lib
PUBLIC "${_common_include_directories}")
target_include_directories(
custom_ops_aot_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include")
target_link_libraries(custom_ops_aot_lib PUBLIC custom_ops torch)
target_compile_options(custom_ops_aot_lib PUBLIC -Wno-deprecated-declarations
-fPIC -frtti -fexceptions)

install(TARGETS custom_ops_aot_lib DESTINATION lib)
endif()
107 changes: 107 additions & 0 deletions examples/models/llama2/custom_ops/op_sdpa_aot.cpp
@@ -0,0 +1,107 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

#include <torch/library.h>

namespace torch {
namespace executor {

namespace native {

Tensor& sdpa_with_kv_cache_out_no_context(
const Tensor& q_projected,
const Tensor& k_projected,
const Tensor& v_projected,
Tensor& key_cache,
Tensor& value_cache,
const int64_t start_pos,
const int64_t seq_len,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
exec_aten::RuntimeContext context{};
return torch::executor::native::sdpa_with_kv_cache_out(
context,
q_projected,
k_projected,
v_projected,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
dropout_p,
is_causal,
scale,
output);
}

at::Tensor sdpa_with_kv_cache_aten(
const at::Tensor& q_projected,
const at::Tensor& k_projected,
const at::Tensor& v_projected,
at::Tensor& key_cache,
at::Tensor& value_cache,
const int64_t start_pos,
const int64_t seq_len,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const c10::optional<at::Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const c10::optional<double> scale) {
auto output = at::empty_like(q_projected);
WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
(q_projected,
k_projected,
v_projected,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
dropout_p,
is_causal,
scale,
output);
return output;
}

} // namespace native
} // namespace executor
} // namespace torch

TORCH_LIBRARY(llama, m) {
m.def(
"sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
m.def(
"sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
}

TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl(
"sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
m.impl(
"sdpa_with_kv_cache.out",
WRAP_TO_ATEN(
torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
}
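
To make the registration flow concrete, here is a minimal usage sketch from Python. The library path and tensor shapes are illustrative assumptions, not part of this PR:

import torch

# Loading the shared library runs the TORCH_LIBRARY / TORCH_LIBRARY_IMPL
# registration above, making the ops visible under torch.ops.llama.
torch.ops.load_library(
    "cmake-out/examples/models/llama2/custom_ops/libcustom_ops_aot_lib.so"  # assumed build path
)

# Illustrative shapes: batch=1, seq_len=1, n_heads=8, head_dim=64, max_seq=128.
q = torch.rand(1, 1, 8, 64)
k = torch.rand(1, 1, 8, 64)
v = torch.rand(1, 1, 8, 64)
k_cache = torch.zeros(1, 128, 8, 64)
v_cache = torch.zeros(1, 128, 8, 64)

out = torch.ops.llama.sdpa_with_kv_cache(q, k, v, k_cache, v_cache, 0, 1)
print(out.shape)  # same shape as q, per the empty_like(q_projected) output above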
111 changes: 20 additions & 91 deletions examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
@@ -4,21 +4,29 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Import the custom ops defined in op_sdpa_aot.cpp. Those ops are registered
# via the PyTorch C++ APIs, so here we need to load the shared library.
# This is only needed for OSS.

import logging
from pathlib import Path

import torch
from torch.library import impl, impl_abstract

custom_ops_lib = torch.library.Library("llama", "DEF")
custom_ops_lib.define(
"sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor"
)
from torch.library import impl

custom_ops_lib.define(
"sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)"
)
try:
op = torch.ops.llama.sdpa_with_kv_cache.default
assert op is not None
except Exception:
libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
logging.info(f"Loading custom ops library: {libs[0]}")
torch.ops.load_library(libs[0])
op = torch.ops.llama.sdpa_with_kv_cache.default
assert op is not None

custom_ops_lib = torch.library.Library("llama", "IMPL")


def _validate_params(
@@ -118,82 +126,3 @@ def sdpa_with_kv_cache_meta(
)

return torch.empty_like(query)


@impl(custom_ops_lib, "sdpa_with_kv_cache", "CompositeExplicitAutograd")
def sdpa_with_kv_cache(
query,
key,
value,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
_validate_params(
query,
key,
value,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
dropout_p,
is_causal,
scale,
)

if attn_mask is not None:
attn_mask = attn_mask[start_pos].view((1, -1))
attn_mask = attn_mask[:, : start_pos + seq_len]
q = query.transpose(1, 2)
key_cache[:, start_pos] = key
value_cache[:, start_pos] = value

sliced_k_cache = key_cache
sliced_v_cache = value_cache
sliced_k_cache = sliced_k_cache[:, : start_pos + seq_len, :, :]
sliced_v_cache = sliced_v_cache[:, : start_pos + seq_len, :, :]
sliced_k_cache = sliced_k_cache.transpose(1, 2)
sliced_v_cache = sliced_v_cache.transpose(1, 2)
out = torch.nn.functional.scaled_dot_product_attention(
q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
)
out = out.transpose(1, 2)
return out


@impl_abstract("llama::sdpa_with_kv_cache.out")
def sdpa_with_kv_cache_out(
query,
key,
value,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
dropout_p,
is_causal,
scale,
out,
):
out = sdpa_with_kv_cache_meta(
query,
key,
value,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
dropout_p,
is_causal,
scale,
)
return out