Skip to content

Commit 4559793

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Add Vulkan backend to llama demo binary (#2608)
Summary: Pull Request resolved: #2608 ## Context Cleans up Vulkan backend CMake build so that `vulkan_backend` target is a single static library will all required symbols. Then add vulkan backend to llama runner binary. imported-using-ghimport Test Plan: Imported from OSS Reviewed By: mcr229, cbilgin Differential Revision: D55272794 Pulled By: SS-JIA fbshipit-source-id: 16074d32315a01df0a2fd70311ae98c07da12dc9
1 parent 1d34cee commit 4559793

File tree

5 files changed

+58
-42
lines changed

5 files changed

+58
-42
lines changed

backends/vulkan/CMakeLists.txt

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,23 @@ set(COMMON_INCLUDES ${VULKAN_API_HEADERS} ${EXECUTORCH_ROOT}/..)
4949
# vulkan_graph_lib
5050

5151
file(GLOB_RECURSE vulkan_graph_cpp ${RUNTIME_PATH}/graph/*)
52+
list(APPEND vulkan_graph_cpp ${vulkan_api_cpp})
5253

5354
add_library(vulkan_graph_lib STATIC ${vulkan_graph_cpp})
5455
target_include_directories(vulkan_graph_lib PRIVATE ${COMMON_INCLUDES})
55-
target_link_libraries(${LIBRARY_NAME} vulkan_api_lib)
5656
target_compile_options(vulkan_graph_lib PRIVATE ${VULKAN_CXX_FLAGS})
5757
# Link this library with --whole-archive due to dynamic operator registrations
5858
target_link_options_shared_lib(vulkan_graph_lib)
5959

60-
# Due to dynamic registrations, these libraries must be explicitly linked
61-
set(VULKAN_STANDARD_OPS_LIBS vulkan_graph_lib vulkan_graph_shaderlib)
62-
63-
# vulkan_graph_shaderlib
60+
# vulkan_standard_shaders
6461

6562
set(VULKAN_GRAPH_SHADERS_PATH ${RUNTIME_PATH}/graph/ops/glsl/)
66-
vulkan_shader_library(${VULKAN_GRAPH_SHADERS_PATH} vulkan_graph_shaderlib)
63+
# Declares a vulkan_standard_shaders library containing compiled SPIR-V shaders
64+
# from the above path
65+
vulkan_shader_library(${VULKAN_GRAPH_SHADERS_PATH} vulkan_standard_shaders)
66+
# preserve the path of the generated cpp file
67+
set(vulkan_standard_shaders_cpp
68+
${CMAKE_BINARY_DIR}/vulkan_standard_shaders/spv.cpp)
6769

6870
# Generate Files from flatc
6971

@@ -90,15 +92,16 @@ target_include_directories(
9092
vulkan_schema INTERFACE ${SCHEMA_INCLUDE_DIR}
9193
${EXECUTORCH_ROOT}/third-party/flatbuffers/include)
9294

93-
# vulkan_backend_lib
95+
# vulkan_backend
9496

9597
file(GLOB vulkan_backend_cpp ${RUNTIME_PATH}/*.cpp)
98+
list(APPEND vulkan_backend_cpp ${vulkan_graph_cpp})
99+
list(APPEND vulkan_backend_cpp ${vulkan_standard_shaders_cpp})
96100

97-
add_library(vulkan_backend ${vulkan_backend_cpp})
101+
add_library(vulkan_backend STATIC ${vulkan_backend_cpp})
98102
target_include_directories(vulkan_backend PRIVATE ${SCHEMA_INCLUDE_DIR}
99103
${COMMON_INCLUDES})
100-
target_link_libraries(vulkan_backend PRIVATE vulkan_graph_lib vulkan_schema
101-
executorch)
104+
target_link_libraries(vulkan_backend PRIVATE vulkan_schema executorch)
102105
target_compile_options(vulkan_backend PRIVATE ${VULKAN_CXX_FLAGS})
103106
# Link this library with --whole-archive due to dynamic backend registration
104107
target_link_options_shared_lib(vulkan_backend)
@@ -110,15 +113,13 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*iOS\.cmake$")
110113
list(TRANSFORM VULKAN_RUNNER_SRCS PREPEND "${EXECUTORCH_ROOT}/")
111114

112115
add_executable(vulkan_executor_runner ${VULKAN_RUNNER_SRCS})
113-
target_link_libraries(
114-
vulkan_executor_runner ${_executor_runner_libs} vulkan_schema
115-
vulkan_backend ${VULKAN_STANDARD_OPS_LIBS})
116+
target_link_libraries(vulkan_executor_runner ${_executor_runner_libs}
117+
vulkan_schema vulkan_backend)
116118
target_compile_options(vulkan_executor_runner PUBLIC ${VULKAN_CXX_FLAGS})
117119

118120
add_library(vulkan_executor_runner_lib STATIC ${VULKAN_RUNNER_SRCS})
119-
target_link_libraries(
120-
vulkan_executor_runner_lib ${_executor_runner_libs} vulkan_schema
121-
vulkan_backend vulkan_api_lib ${VULKAN_STANDARD_OPS_LIBS})
121+
target_link_libraries(vulkan_executor_runner_lib ${_executor_runner_libs}
122+
vulkan_schema vulkan_backend)
122123
target_compile_options(vulkan_executor_runner_lib PUBLIC ${VULKAN_CXX_FLAGS})
123124
endif()
124125

@@ -140,7 +141,13 @@ if(EXECUTORCH_BUILD_GTESTS)
140141
target_include_directories(vulkan_compute_api_test
141142
PRIVATE ${COMMON_INCLUDES} ${TEST_UTILS_HEADERS})
142143
target_link_libraries(
143-
vulkan_compute_api_test PRIVATE gtest_main ${VULKAN_STANDARD_OPS_LIBS}
144-
test_shaderlib)
144+
vulkan_compute_api_test PRIVATE gtest_main vulkan_graph_lib
145+
vulkan_standard_shaders test_shaderlib)
145146
target_compile_options(vulkan_compute_api_test PRIVATE ${VULKAN_CXX_FLAGS})
146147
endif()
148+
149+
install(
150+
TARGETS vulkan_backend
151+
DESTINATION lib
152+
INCLUDES
153+
DESTINATION ${COMMON_INCLUDES})

backends/vulkan/cmake/ATenVulkan.cmake

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ if(NOT VULKAN_THIRD_PARTY_PATH)
2222
set(VULKAN_THIRD_PARTY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../third-party)
2323
endif()
2424

25-
# Source paths and compile settings
25+
# Source file paths
2626

2727
set(ATEN_PATH ${PYTORCH_PATH}/aten/src)
2828
set(ATEN_VULKAN_PATH ${ATEN_PATH}/ATen/native/vulkan)
@@ -31,28 +31,26 @@ set(VULKAN_HEADERS_PATH ${VULKAN_THIRD_PARTY_PATH}/Vulkan-Headers/include)
3131
set(VOLK_PATH ${VULKAN_THIRD_PARTY_PATH}/volk)
3232
set(VMA_PATH ${VULKAN_THIRD_PARTY_PATH}/VulkanMemoryAllocator)
3333

34+
# Compile settings
35+
3436
set(VULKAN_CXX_FLAGS "")
3537
list(APPEND VULKAN_CXX_FLAGS "-DUSE_VULKAN_API")
3638
list(APPEND VULKAN_CXX_FLAGS "-DUSE_VULKAN_WRAPPER")
3739
list(APPEND VULKAN_CXX_FLAGS "-DUSE_VULKAN_VOLK")
3840
list(APPEND VULKAN_CXX_FLAGS "-DVK_NO_PROTOTYPES")
3941
list(APPEND VULKAN_CXX_FLAGS "-DVOLK_DEFAULT_VISIBILITY")
4042

41-
# vulkan_api_lib
42-
43-
file(GLOB VULKAN_API_CPP ${ATEN_VULKAN_PATH}/api/*.cpp)
43+
# vulkan_api source files
4444

45-
add_library(vulkan_api_lib STATIC ${VULKAN_API_CPP} ${VOLK_PATH}/volk.c)
45+
file(GLOB vulkan_api_cpp ${ATEN_VULKAN_PATH}/api/*.cpp)
46+
list(APPEND vulkan_api_cpp ${VOLK_PATH}/volk.c)
4647

4748
set(VULKAN_API_HEADERS)
4849
list(APPEND VULKAN_API_HEADERS ${ATEN_PATH})
4950
list(APPEND VULKAN_API_HEADERS ${VULKAN_HEADERS_PATH})
5051
list(APPEND VULKAN_API_HEADERS ${VOLK_PATH})
5152
list(APPEND VULKAN_API_HEADERS ${VMA_PATH})
5253

53-
target_include_directories(vulkan_api_lib PRIVATE ${VULKAN_API_HEADERS})
54-
target_compile_options(vulkan_api_lib PRIVATE ${VULKAN_CXX_FLAGS})
55-
5654
# Find GLSL compiler executable
5755

5856
if(ANDROID)
@@ -65,9 +63,8 @@ if(ANDROID)
6563
else()
6664
find_program(
6765
GLSLC_PATH glslc
68-
PATHS ENV VULKAN_SDK
69-
PATHS "$ENV{VULKAN_SDK}/${CMAKE_HOST_SYSTEM_PROCESSOR}/bin"
70-
PATHS "$ENV{VULKAN_SDK}/bin")
66+
PATHS ENV VULKAN_SDK "$ENV{VULKAN_SDK}/${CMAKE_HOST_SYSTEM_PROCESSOR}/bin"
67+
"$ENV{VULKAN_SDK}/bin")
7168

7269
if(NOT GLSLC_PATH)
7370
message(FATAL_ERROR "USE_VULKAN glslc not found")
@@ -77,29 +74,31 @@ endif()
7774
# Required to enable linking with --whole-archive
7875
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
7976

80-
# Convenience macro to create a shader library
81-
82-
macro(vulkan_shader_library SHADERS_PATH LIBRARY_NAME)
77+
# Convenience macro to generate a SPIR-V shader library target. Given the path
78+
# to the shaders to compile and the name of the library, it will create a static
79+
# library containing the generated SPIR-V shaders. The generated_spv_cpp
80+
# variable can be used to reference the generated CPP file outside the macro.
81+
macro(VULKAN_SHADER_LIBRARY shaders_path library_name)
8382
set(VULKAN_SHADERGEN_ENV "")
84-
set(VULKAN_SHADERGEN_OUT_PATH ${CMAKE_BINARY_DIR}/${LIBRARY_NAME})
83+
set(VULKAN_SHADERGEN_OUT_PATH ${CMAKE_BINARY_DIR}/${library_name})
8584

8685
execute_process(
8786
COMMAND
8887
"${PYTHON_EXECUTABLE}" ${PYTORCH_PATH}/tools/gen_vulkan_spv.py --glsl-path
89-
${SHADERS_PATH} --output-path ${VULKAN_SHADERGEN_OUT_PATH}
88+
${shaders_path} --output-path ${VULKAN_SHADERGEN_OUT_PATH}
9089
--glslc-path=${GLSLC_PATH} --tmp-dir-path=${VULKAN_SHADERGEN_OUT_PATH}
9190
--env ${VULKAN_GEN_ARG_ENV}
9291
RESULT_VARIABLE error_code)
9392
set(ENV{PYTHONPATH} ${PYTHONPATH})
9493

95-
set(vulkan_generated_cpp ${VULKAN_SHADERGEN_OUT_PATH}/spv.cpp)
94+
set(generated_spv_cpp ${VULKAN_SHADERGEN_OUT_PATH}/spv.cpp)
9695

97-
add_library(${LIBRARY_NAME} STATIC ${vulkan_generated_cpp})
98-
target_include_directories(${LIBRARY_NAME} PRIVATE ${COMMON_INCLUDES})
99-
target_link_libraries(${LIBRARY_NAME} vulkan_api_lib)
100-
target_compile_options(${LIBRARY_NAME} PRIVATE ${VULKAN_CXX_FLAGS})
96+
add_library(${library_name} STATIC ${generated_spv_cpp})
97+
target_include_directories(${library_name} PRIVATE ${COMMON_INCLUDES})
98+
target_link_libraries(${library_name} vulkan_graph_lib)
99+
target_compile_options(${library_name} PRIVATE ${VULKAN_CXX_FLAGS})
101100
# Link this library with --whole-archive due to dynamic shader registrations
102-
target_link_options_shared_lib(${LIBRARY_NAME})
101+
target_link_options_shared_lib(${library_name})
103102

104103
unset(VULKAN_SHADERGEN_ENV)
105104
unset(VULKAN_SHADERGEN_OUT_PATH)

build/executorch-config.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ endif()
4444
set(lib_list
4545
etdump bundled_program extension_data_loader ${FLATCC_LIB} mpsdelegate
4646
qnn_executorch_backend portable_ops_lib extension_module xnnpack_backend
47-
XNNPACK cpuinfo pthreadpool
47+
XNNPACK cpuinfo pthreadpool vulkan_backend
4848
)
4949
foreach(lib ${lib_list})
5050
# Name of the variable which stores result of the find_library search

examples/models/llama2/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ if(TARGET xnnpack_backend)
6969
$<TARGET_FILE:xnnpack_backend> \
7070
LINKER:--no-whole-archive")
7171
endif()
72+
73+
# Vulkan backend
74+
if(TARGET vulkan_backend)
75+
target_link_libraries(llama_main PUBLIC vulkan_backend)
76+
target_link_options(
77+
llama_main PUBLIC "SHELL:LINKER:--whole-archive \
78+
$<TARGET_FILE:vulkan_backend> \
79+
LINKER:--no-whole-archive")
80+
endif()
81+
7282
target_compile_options(llama_main PUBLIC ${_common_compile_options})
7383

7484
# Print all summary

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901
577577

578578
if args.vulkan:
579579
assert (
580-
args.dtype_override is None
580+
args.dtype_override == "fp32" or args.dtype_override is None
581581
), "Vulkan backend does not support non fp32 dtypes at the moment"
582582
assert (
583583
args.quantization_mode is None

0 commit comments

Comments
 (0)