Skip to content

Commit 3881bff

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
Cmake sdpa custom op (#2671)
Summary: Pull Request resolved: #2671 Enable `--use_sdpa_with_kv_cache` with Cmake. ghstack-source-id: 220378207 exported-using-ghexport Pre-existing lint failures bypass-github-export-checks bypass-github-pytorch-ci-checks Reviewed By: digantdesai, mcr229 Differential Revision: D55344016 fbshipit-source-id: 7dafab86dc66da85c7a988a728bb5cce5382e3d6
1 parent 33e13a2 commit 3881bff

File tree

7 files changed

+138
-27
lines changed

7 files changed

+138
-27
lines changed

CMakeLists.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,26 @@ endif()
4545
if(NOT CMAKE_BUILD_TYPE)
4646
set(CMAKE_BUILD_TYPE Debug)
4747
endif()
48+
49+
# --- cpuinfo
# Build cpuinfo as a static library with its tools, tests and benchmarks
# disabled. The source paths are anchored at this file's directory because
# CPUINFO_SOURCE_DIR and CLOG_SOURCE_DIR are inherited by every directory
# added below this one, where a bare relative path would be resolved against
# that directory instead.
50+
set(CPUINFO_SOURCE_DIR
    "${CMAKE_CURRENT_SOURCE_DIR}/backends/xnnpack/third-party/cpuinfo")
51+
set(CPUINFO_BUILD_TOOLS OFF CACHE BOOL "")
52+
set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "")
53+
set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "")
54+
set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "")
55+
set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
56+
set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
57+
set(CLOG_SOURCE_DIR "${CPUINFO_SOURCE_DIR}/deps/clog")
58+
add_subdirectory("${CPUINFO_SOURCE_DIR}")
59+
60+
# --- pthreadpool
# Build pthreadpool as a static library with tests and benchmarks disabled and
# the deprecated API allowed. The source path is anchored at this file's
# directory so that PTHREADPOOL_SOURCE_DIR, which nested directories inherit,
# means the same location wherever it is read.
61+
set(PTHREADPOOL_SOURCE_DIR
    "${CMAKE_CURRENT_SOURCE_DIR}/backends/xnnpack/third-party/pthreadpool")
62+
set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "")
63+
set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "")
64+
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
65+
set(PTHREADPOOL_ALLOW_DEPRECATED_API ON CACHE BOOL "")
66+
add_subdirectory("${PTHREADPOOL_SOURCE_DIR}")
67+
4868
# ------------------------------ OPTIONS -------------------------------------
4969
# WARNING: Please don't add example specific options in this CMakeLists.txt.
5070
# Instead please use `find_package(executorch REQUIRED)` in the example

backends/xnnpack/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ target_link_libraries(xnnpack_backend
8787
target_include_directories(xnnpack_backend
8888
PUBLIC ${_common_include_directories})
8989
target_include_directories(xnnpack_backend PUBLIC ${XNNPACK_INCLUDE_DIR})
90+
target_include_directories(xnnpack_backend PUBLIC
91+
${CMAKE_CURRENT_SOURCE_DIR}/third-party/pthreadpool/include)
92+
target_include_directories(xnnpack_backend PUBLIC
93+
${CMAKE_CURRENT_SOURCE_DIR}/third-party/cpuinfo/include)
9094
target_compile_options(xnnpack_backend PUBLIC ${_common_compile_options})
9195
target_link_options_shared_lib(xnnpack_backend)
9296

@@ -101,7 +105,8 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*iOS\.cmake$")
101105
list(TRANSFORM _xnn_executor_runner__srcs PREPEND "${EXECUTORCH_ROOT}/")
102106
add_executable(xnn_executor_runner ${_xnn_executor_runner__srcs})
103107
target_link_libraries(xnn_executor_runner
104-
xnnpack_backend gflags portable_ops_lib)
108+
xnnpack_backend gflags portable_ops_lib
109+
pthreadpool cpuinfo)
105110
target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options})
106111
endif()
107112

backends/xnnpack/cmake/Dependencies.cmake

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,6 @@
88
# https://github.com/pytorch/pytorch/blob/main/cmake/Dependencies.cmake
99
set(THIRD_PARTY_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/third-party")
1010

11-
# --- cpuinfo
12-
set(CPUINFO_SOURCE_DIR "${THIRD_PARTY_ROOT}/cpuinfo")
13-
set(CPUINFO_BUILD_TOOLS OFF CACHE BOOL "")
14-
set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "")
15-
set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "")
16-
set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "")
17-
set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
18-
set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
19-
set(CLOG_SOURCE_DIR "${CPUINFO_SOURCE_DIR}/deps/clog")
20-
add_subdirectory("${CPUINFO_SOURCE_DIR}")
21-
list(APPEND xnnpack_third_party cpuinfo)
22-
23-
# --- pthreadpool
24-
set(PTHREADPOOL_SOURCE_DIR "${THIRD_PARTY_ROOT}/pthreadpool")
25-
set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "")
26-
set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "")
27-
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
28-
set(PTHREADPOOL_ALLOW_DEPRECATED_API ON CACHE BOOL "")
29-
add_subdirectory("${PTHREADPOOL_SOURCE_DIR}")
30-
list(APPEND xnnpack_third_party pthreadpool)
31-
3211
# --- XNNPACK
3312
set(XNNPACK_SOURCE_DIR "${THIRD_PARTY_ROOT}/XNNPACK")
3413
set(XNNPACK_INCLUDE_DIR "${XNNPACK_SOURCE_DIR}/include")

build/cmake_deps.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,21 @@ filters = [
265265
]
266266
# ---------------------------------- Vulkan end -----------------------------------
267267
# ---------------------------------- LLama start ----------------------------------
268+
[targets.custom_ops]
269+
buck_targets = [
270+
"//examples/models/llama2/custom_ops:custom_ops",
271+
]
272+
filters = [
273+
".cpp$",
274+
]
275+
excludes = [
276+
"^codegen",
277+
# Exclude blas, since it's built as a separate target.
278+
# "^kernels/optimized/blas",
279+
]
280+
deps = [
281+
"executorch",
282+
]
268283

269284
[targets.llama_runner]
270285
buck_targets = [
@@ -277,6 +292,7 @@ excludes = [
277292
"^codegen",
278293
]
279294
deps = [
295+
"custom_ops",
280296
"executorch",
281297
"extension_data_loader",
282298
"extension_module",

examples/models/llama2/CMakeLists.txt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,26 @@ find_package(gflags REQUIRED PATHS
4545
#
4646
# llama_main: test binary to run llama, with tokenizer and sampler integrated
4747
#
48-
add_executable(llama_main main.cpp)
48+
add_executable(llama_main main.cpp
49+
${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/threadpool/cpuinfo_utils.cpp)
4950
# Strip unreferenced sections from Release builds of llama_main.
# Build-type names are strings, so they are compared with STREQUAL; EQUAL is a
# numeric comparison and never matched here. Upper-casing first accepts both
# "Release" and "RELEASE".
string(TOUPPER "${CMAKE_BUILD_TYPE}" _llama_main_build_type)
if("${_llama_main_build_type}" STREQUAL "RELEASE")
5051
  if(APPLE)
    # Apple's linker does not accept --gc-sections; -dead_strip is its
    # equivalent.
    target_link_options(llama_main PRIVATE "LINKER:-dead_strip")
  else()
    target_link_options(llama_main PRIVATE "LINKER:--gc-sections")
  endif()
5152
endif()
5253

5354
# find `executorch` libraries
5455
find_package(executorch CONFIG REQUIRED)
5556

57+
# custom ops library
58+
add_subdirectory(custom_ops)
59+
5660
# llama_runner library
5761
add_subdirectory(runner)
5862

63+
target_include_directories(llama_main PUBLIC
64+
${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/cpuinfo/include)
65+
target_include_directories(llama_main PUBLIC
66+
${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/pthreadpool/include)
67+
5968
set(link_libraries)
6069

6170
if(EXECUTORCH_BUILD_OPTIMIZED)
@@ -67,7 +76,7 @@ else()
6776
target_link_options_shared_lib(portable_ops_lib)
6877
endif()
6978

70-
target_link_libraries(llama_main PUBLIC gflags llama_runner)
79+
target_link_libraries(llama_main PUBLIC gflags llama_runner custom_ops_lib)
7180

7281
# XNNPACK pthreadpool cpuinfo
7382
if(TARGET xnnpack_backend)
@@ -82,7 +91,8 @@ if(TARGET vulkan_backend)
8291
target_link_options_shared_lib(vulkan_backend)
8392
endif()
8493

85-
target_compile_options(llama_main PUBLIC ${_common_compile_options})
94+
# Common warning flags go through compile options; the ET_USE_THREADPOOL macro
# is a preprocessor definition, so it is declared with
# target_compile_definitions rather than as a raw -D flag.
target_compile_options(llama_main PUBLIC ${_common_compile_options})
95+
target_compile_definitions(llama_main PUBLIC ET_USE_THREADPOOL)
8696
target_link_libraries(llama_main PUBLIC ${link_libraries})
8797

8898
if(APPLE)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
cmake_minimum_required(VERSION 3.19)
9+
10+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
11+
if(NOT CMAKE_CXX_STANDARD)
12+
set(CMAKE_CXX_STANDARD 17)
13+
endif()
14+
15+
if(NOT PYTHON_EXECUTABLE)
16+
set(PYTHON_EXECUTABLE python3)
17+
endif()
18+
19+
# Source root directory for executorch.
20+
if(NOT EXECUTORCH_ROOT)
21+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
22+
endif()
23+
24+
# Source root directory for pytorch.
25+
if(NOT TORCH_ROOT)
26+
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
27+
endif()
28+
29+
set(_common_compile_options -Wno-deprecated-declarations)
30+
31+
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
32+
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
33+
34+
#
35+
# The `_<target>_srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}.
36+
#
37+
set(EXECUTORCH_SRCS_FILE
38+
"${CMAKE_CURRENT_BINARY_DIR}/../../../../executorch_srcs.cmake")
39+
40+
extract_sources(${EXECUTORCH_SRCS_FILE})
41+
42+
include(${EXECUTORCH_SRCS_FILE})
43+
44+
# Let files say "include <executorch/path/to/header.h>".
45+
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
46+
47+
48+
# Custom op libraries
49+
set(custom_ops_libs extension_module)
50+
list(APPEND custom_ops_libs pthreadpool)
51+
list(APPEND custom_ops_libs cpuinfo)
52+
53+
# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
54+
# Executorch (for runtime). Here select all ops in custom_ops.yaml
55+
set(_yaml "${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml")
56+
gen_selected_ops("${_yaml}" "" "")
57+
58+
generate_bindings_for_kernels(
59+
FUNCTIONS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/custom_ops.yaml)
60+
message("Generated files ${gen_command_sources}")
61+
62+
list(TRANSFORM _custom_ops__srcs PREPEND "${EXECUTORCH_ROOT}/")
63+
add_library(custom_ops ${_custom_ops__srcs})
64+
65+
target_include_directories(custom_ops PUBLIC "${_common_include_directories}")
66+
target_include_directories(custom_ops PRIVATE
67+
"${CMAKE_CURRENT_BINARY_DIR}/../../../../include")
68+
target_link_libraries(custom_ops PRIVATE ${custom_ops_libs})
69+
70+
# Common warning flags go through compile options; the ET_USE_THREADPOOL macro
# is a preprocessor definition, so it is declared with
# target_compile_definitions rather than as a raw -D flag. PUBLIC keeps it
# visible to targets that link custom_ops.
target_compile_options(custom_ops PUBLIC ${_common_compile_options})
71+
target_compile_definitions(custom_ops PUBLIC ET_USE_THREADPOOL)
72+
73+
# Build a library for _custom_ops__srcs
74+
#
75+
# custom_ops_lib: Register custom ops kernels into Executorch runtime
76+
gen_operators_lib(
77+
"custom_ops_lib"
78+
KERNEL_LIBS custom_ops
79+
DEPS executorch)

examples/models/llama2/runner/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,12 @@ else()
4747
add_library(llama_runner SHARED ${_llama_runner__srcs})
4848
endif()
4949

50-
set(llama_runner_deps executorch extension_module extension_data_loader)
50+
set(llama_runner_deps executorch extension_module extension_data_loader
51+
custom_ops)
5152

5253
target_link_libraries(
5354
llama_runner PUBLIC ${llama_runner_deps})
5455

5556
target_include_directories(llama_runner
56-
INTERFACE ${_common_include_directories} ${EXECUTORCH_ROOT})
57+
INTERFACE ${_common_include_directories}
58+
${EXECUTORCH_ROOT})

0 commit comments

Comments
 (0)