Commit 51de4a0

[NOT FOR LAND] Prototype to register llama.cpp kernels into ExecuTorch
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 3ccfe0c commit 51de4a0

6 files changed: +382 additions, -35 deletions

examples/llama_cpp/CMakeLists.txt

Lines changed: 58 additions & 35 deletions
@@ -17,59 +17,82 @@
 #
 
 cmake_minimum_required(VERSION 3.19)
+project(LlamaCppExample)
+
 set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
 set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
 include(${EXECUTORCH_ROOT}/build/Utils.cmake)
 include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
 
+set(_common_compile_options -Wno-deprecated-declarations -fPIC)
+
+# Let files say "include <executorch/path/to/header.h>".
+set(_common_include_directories ${EXECUTORCH_ROOT}/..)
+
+find_package(Llama REQUIRED)
+find_package(ExecuTorch REQUIRED)
+find_package(
+  gflags REQUIRED PATHS ${CMAKE_CURRENT_BINARY_DIR}/../../third-party
+)
+
+target_include_directories(executorch INTERFACE ${_common_include_directories})
+
 #
-# select_build_lib: C++ library to register selected ops in custom kernel
-# library
+# The `_<target>_srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}.
 #
-set(_kernel_lib)
-if(EXECUTORCH_SELECT_OPS_YAML)
-  set(_custom_ops_yaml
-      ${EXECUTORCH_ROOT}/examples/portable/custom_ops/custom_ops.yaml)
-  gen_selected_ops("${_custom_ops_yaml}" "" "")
-  set(kernel_sources
-      ${EXECUTORCH_ROOT}/examples/portable/custom_ops/custom_ops_1_out.cpp
-      ${EXECUTORCH_ROOT}/examples/portable/custom_ops/custom_ops_2_out.cpp)
-  #
-  # custom_kernels: C++ kernel implementations of custom ops
+set(
+  EXECUTORCH_SRCS_FILE
+  "${CMAKE_CURRENT_BINARY_DIR}/../../executorch_srcs.cmake"
+)
+if(NOT EXISTS "${EXECUTORCH_SRCS_FILE}")
+  # A file wasn't generated. Run a script to extract the source lists from the
+  # buck2 build system and write them to a file we can include.
   #
-  add_library(custom_kernels ${kernel_sources})
-  target_link_libraries(custom_kernels PRIVATE executorch)
-  target_compile_options(custom_kernels PUBLIC ${_common_compile_options})
-
-  list(APPEND _kernel_lib custom_kernels)
-else()
-  list(APPEND _kernel_lib portable_kernels)
+  # NOTE: This will only happen once during cmake setup, so it will not re-run
+  # if the buck2 targets change.
+  message(STATUS "executorch: Generating source lists")
+  set(EXECUTORCH_SRCS_FILE "${CMAKE_CURRENT_BINARY_DIR}/executorch_srcs.cmake")
+  extract_sources(${EXECUTORCH_SRCS_FILE})
 endif()
 
-gen_selected_ops(
-  "${_custom_ops_yaml}"
-  "${EXECUTORCH_SELECT_OPS_LIST}"
-  "${EXECUTORCH_SELECT_ALL_OPS}")
+# This file defines the `_<target>__srcs` variables used below.
+message(STATUS "executorch: Using sources file ${EXECUTORCH_SRCS_FILE}")
+include(${EXECUTORCH_SRCS_FILE})
+
+set(_custom_ops_yaml ${EXECUTORCH_ROOT}/examples/llama_cpp/custom_ops.yaml)
+set(_ops_yaml ${EXECUTORCH_ROOT}/kernels/portable/functions.yaml)
 
-generate_bindings_for_kernels(${EXECUTORCH_ROOT}/kernels/portable/functions.yaml
-                              "${_custom_ops_yaml}")
-gen_operators_lib("select_build_lib" ${_kernel_lib} executorch)
+set(kernel_sources ${EXECUTORCH_ROOT}/examples/llama_cpp/op_mm.cpp)
+#
+# custom_kernels: C++ kernel implementations of custom ops
+#
+add_library(custom_kernels ${kernel_sources})
+target_link_libraries(custom_kernels PRIVATE executorch llama)
+target_compile_options(custom_kernels PUBLIC ${_common_compile_options})
+
+set(_kernel_lib custom_kernels portable_kernels)
+
+# Select all ops in functions.yaml as well as custom op.
+gen_selected_ops("${_ops_yaml}" "ggml::mul_mat.out" "")
+
+#
+# kernel_lib: contains both custom_kernels and portable_kernels
+#
+generate_bindings_for_kernels("${_ops_yaml}" "${_custom_ops_yaml}")
+gen_operators_lib("kernel_lib" ${_kernel_lib} executorch)
+target_link_libraries(kernel_lib PRIVATE executorch)
 
-set(_updated__srcs)
-foreach(_src ${_executor_runner__srcs})
-  list(APPEND _updated__srcs "${EXECUTORCH_ROOT}/${_src}")
-endforeach()
+list(TRANSFORM _executor_runner__srcs PREPEND "${EXECUTORCH_ROOT}/")
 
 #
-# selective_build_test: test binary to allow different operator libraries to
-# link to
+# llama_cpp_test: test binary to run llama.cpp kernel ggml_mul_mat
 #
-add_executable(selective_build_test ${_updated__srcs})
+add_executable(llama_cpp_test ${_executor_runner__srcs})
 if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
   target_link_options(selective_build_test PRIVATE "LINKER:--gc-sections")
 endif()
-target_link_libraries(selective_build_test executorch gflags select_build_lib)
-target_compile_options(selective_build_test PUBLIC ${_common_compile_options})
+target_link_libraries(llama_cpp_test executorch gflags kernel_lib)
+target_compile_options(llama_cpp_test PUBLIC ${_common_compile_options})
 
 # Print all summary
 executorch_print_configuration_summary()

examples/llama_cpp/custom_ops.yaml

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
- func: ggml::mul_mat.out(Tensor in, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: llama_cpp::mm_out

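The schema above is the out-variant that the ExecuTorch runtime calls. Its intended eager-mode behavior mirrors the Python out_kernel registered in the fusion-pass file below: view mat2 as [1, 64], then matmul into the preallocated output. A minimal illustration; the shapes here are assumptions chosen to match the prototype's hardcoded 64-element view:

import torch

inp = torch.randn(64, 16)   # "in" argument
mat2 = torch.randn(8, 8)    # "mat2" argument: any 64-element tensor
out = torch.empty(1, 16)

# Reference computation that ggml::mul_mat.out is meant to reproduce,
# per the out_kernel defined in permute_mm_fusion_pass.py below.
torch.mm(mat2.reshape(1, 64), inp, out=out)
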
examples/llama_cpp/export.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import argparse
import logging

from ..models import MODEL_NAME_TO_MODEL
from ..models.model_factory import EagerModelFactory
from ..portable.utils import export_to_edge, save_pte_program
from .permute_mm_fusion_pass import PermuteMMFusionPass
from torch._export import capture_pre_autograd_graph

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":
    model, example_inputs = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL["llama2"]
    )
    m = model.eval()
    # pre-autograd export. eventually this will become torch.export
    m = capture_pre_autograd_graph(m, example_inputs)

    edge_ir = export_to_edge(m, example_inputs).transform(
        [PermuteMMFusionPass(_fix_node_meta_val=True)]
    )
    print(f"Exported graph:\n{edge_ir.exported_program().graph}")

    prog = edge_ir.to_executorch()

    save_pte_program(prog.buffer, "llama2_fused")

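Beyond printing the graph, one can check programmatically that the fusion actually happened by scanning the Edge program for the fused op. A small sketch, assuming edge_ir was produced as in export.py above (the helper name is illustrative, not part of the commit):

def count_fused_mul_mat(edge_ir) -> int:
    """Count call_function nodes whose target is the fused ggml mul_mat op."""
    graph = edge_ir.exported_program().graph
    return sum(
        1
        for node in graph.nodes
        if node.op == "call_function"
        and "ggml" in str(node.target)
        and "mul_mat" in str(node.target)
    )
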
examples/llama_cpp/op_mm.cpp

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include "ggml.h"
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace llama_cpp {
namespace native {

using Tensor = exec_aten::Tensor;
using RuntimeContext = exec_aten::RuntimeContext;
using Error = torch::executor::Error;

// Helper function to create a ggml tensor with preallocated memory
static struct ggml_tensor * ggml_tensor_from(const Tensor & t, const int64_t * ne_override) {
  // HACK: since this is only used by mm, hardcode n_dims to 2
  // Should be t.dim() but that requires refactoring
  int n_dims = 2;
  // ET_CHECK_MSG(n_dims >= 1 && n_dims <= GGML_MAX_DIMS, "dimension %d is not within range (1, %d)", n_dims, GGML_MAX_DIMS);

  void * data = t.mutable_data_ptr();

  // TODO use memory from context to create tensor
  struct ggml_tensor * const result = (struct ggml_tensor *) malloc(sizeof(struct ggml_tensor));

  ET_CHECK_MSG(t.scalar_type() == exec_aten::ScalarType::Float, "only float type supported");
  // TODO support different types
  enum ggml_type type = ggml_type::GGML_TYPE_F32;
  *result = (struct ggml_tensor) {
      /*.type         =*/ type,
      /*.backend      =*/ GGML_BACKEND_CPU,
      /*.buffer       =*/ NULL,
      /*.n_dims       =*/ n_dims,
      /*.ne           =*/ { 1, 1, 1, 1 },
      /*.nb           =*/ { 0, 0, 0, 0 },
      /*.op           =*/ GGML_OP_NONE,
      /*.op_params    =*/ { 0 },
      /*.is_param     =*/ false,
      /*.grad         =*/ NULL,
      /*.src          =*/ { NULL },
      /*.perf_runs    =*/ 0,
      /*.perf_cycles  =*/ 0,
      /*.perf_time_us =*/ 0,
      /*.view_src     =*/ NULL,
      /*.view_offs    =*/ 0,
      /*.data         =*/ data,
      /*.name         =*/ { 0 },
      /*.extra        =*/ NULL,
      /*.padding      =*/ { 0 },
  };

  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
  // ggml_assert_aligned(result->data);

  if (ne_override != NULL) {
    for (int i = 0; i < n_dims; i++) {
      result->ne[i] = ne_override[i];
    }
  } else {
    for (int i = 0; i < n_dims; i++) {
      result->ne[i] = t.sizes()[i];
    }
  }

  result->nb[0] = ggml_type_size(type);
  result->nb[1] = result->nb[0] * (result->ne[0] / ggml_blck_size(type));
  for (int i = 2; i < GGML_MAX_DIMS; i++) {
    result->nb[i] = result->nb[i - 1] * result->ne[i - 1];
  }

  // ctx->n_objects++;

  return result;
}

// View(mat2, {1, 64}), transpose, then matmul.
Tensor&
mm_out(RuntimeContext& ctx, const Tensor& in, const Tensor& mat2, Tensor& out) {
  // prepare input tensors
  // HACK: view(mat2, {64, 1});
  const int64_t dims[4] = {64, 1, 1, 1};

  struct ggml_tensor * a = ggml_tensor_from(in, NULL);

  struct ggml_tensor * b = ggml_tensor_from(mat2, dims);

  // GGML_ASSERT(ggml_can_mul_mat(b, a));
  // GGML_ASSERT(!ggml_is_transposed(b));

  const int64_t ne[4] = { b->ne[1], a->ne[1], a->ne[2], a->ne[3] };
  struct ggml_tensor * result = ggml_tensor_from(out, ne);

  result->op = GGML_OP_MUL_MAT;
  result->grad = NULL;
  result->src[0] = b;
  result->src[1] = a;

  // run op
  struct ggml_cgraph gf = ggml_build_forward(result);
  struct ggml_cplan plan = ggml_graph_plan(&gf, /*int n_threads*/1);
  int res = ggml_graph_compute(&gf, &plan);

  return out;
}

} // namespace native
} // namespace llama_cpp
examples/llama_cpp/permute_mm_fusion_pass.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, List, Tuple

import torch
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
from executorch.exir.pass_base import ExportPass

from executorch.exir.passes.replace_aten_with_edge_pass import (
    aten_to_edge,
    should_lower_to_edge,
)
from torch import fx
from torch.fx import GraphModule, subgraph_rewriter
from torch.fx.passes.infra.pass_base import PassResult
from torch.utils import _pytree as pytree

from torch.library import impl, Library

custom_ops_lib = Library("ggml", "DEF")

custom_ops_lib.define(
    "mul_mat.out(Tensor input, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"
)

custom_ops_lib.define("mul_mat(Tensor input, Tensor mat2) -> Tensor")


def out_kernel(a, b, *, out):
    d = torch.ops.aten.view_copy.default(b, [1, 64])
    e = torch.ops.aten.mm.out(d, a, out=out)
    return out


custom_ops_lib.impl("mul_mat.out", out_kernel)


def _trace_and_lower_to_edge_ops(f: Callable) -> fx.GraphModule:
    gm = fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        if node.op == "call_function" and should_lower_to_edge(node.target):
            node.target = aten_to_edge(node.target)
    gm.recompile()
    return gm


# Fuse the following pattern:
# - d = view_copy(b, [1, 64])
# - e = mm(d, a)


def get_patterns_and_replacements() -> List[Tuple[Callable, Callable, List[Callable]]]:
    @bind_pattern_to_op(custom_ops_lib, "mul_mat")
    def pattern(a, b):
        d = torch.ops.aten.view_copy.default(b, [1, 64])
        e = torch.ops.aten.mm.default(d, a)
        return e

    def replacement(a, b):
        return torch.ops.ggml.mul_mat.default(a, b)

    p_graph = _trace_and_lower_to_edge_ops(pattern)
    r_graph = _trace_and_lower_to_edge_ops(replacement)
    # print(p_graph.graph)
    # print(r_graph.graph)
    return [
        (
            p_graph,
            r_graph,
            [],
        )
    ]


class PermuteMMFusionPass(ExportPass):
    def __init__(self, _fix_node_meta_val=False):
        super().__init__()
        self._fix_node_meta_val = _fix_node_meta_val

    def call(self, graph_module: GraphModule) -> PassResult:
        for (
            pattern,
            replacement,
            match_filters,
        ) in get_patterns_and_replacements():
            subgraph_rewriter.replace_pattern_with_filters(
                graph_module, pattern, replacement, match_filters
            )

        if self._fix_node_meta_val:
            for n in graph_module.graph.nodes:
                if n.op == "call_function" and "val" not in n.meta:
                    args, kwargs = pytree.tree_map_only(
                        torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
                    )
                    n.meta["val"] = n.target(*args, **kwargs)
        graph_module.graph.lint()
        graph_module.graph.eliminate_dead_code()
        return PassResult(graph_module, True)

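A minimal end-to-end sketch of exercising this pass on a toy module, separate from the llama2 export in export.py. It assumes the same export utilities used above are importable as executorch.examples.portable.utils, that PermuteMMFusionPass from this file is in scope, and illustrative shapes matching the hardcoded 64-element view; it is not part of the commit:

import torch
from executorch.examples.portable.utils import export_to_edge
from torch._export import capture_pre_autograd_graph


class ToyMM(torch.nn.Module):
    def forward(self, a, b):
        # The exact pattern the pass looks for: view_copy(b, [1, 64]) followed by mm.
        d = torch.ops.aten.view_copy.default(b, [1, 64])
        return torch.ops.aten.mm.default(d, a)


example_inputs = (torch.randn(64, 16), torch.randn(8, 8))
m = capture_pre_autograd_graph(ToyMM().eval(), example_inputs)
edge = export_to_edge(m, example_inputs).transform(
    [PermuteMMFusionPass(_fix_node_meta_val=True)]
)
# Expect a single ggml.mul_mat call in place of the view_copy + mm pair.
print(edge.exported_program().graph)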