Commit 75284d2

larryliu0820 authored and facebook-github-bot committed
Add template based unboxing (#1284)

Summary:
Pull Request resolved: #1284

Adds a new feature that allows users to bypass codegen and register their kernels directly. This is very useful for custom kernels for custom ops. Example usage:

```
Tensor& my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& out) {
  // ...
  return out;
}

Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op", EXECUTORCH_FN(my_op));
register_kernels({my_kernel});
```

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: iseeyuan

Differential Revision: D51553099

Pulled By: larryliu0820

fbshipit-source-id: 0e3877a481d58c4e64c7767f7693537407ab27c5
1 parent 4dfb637 commit 75284d2

File tree

9 files changed: +557 additions, −0 deletions

extension/kernel_util/README.md

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
This header file `make_boxed_from_unboxed_functor.h` defines a template that can be used to create a boxed version of an unboxed functor. It is part of the executorch extension in the torch namespace.

## Requirements

This header requires C++17 or later.

## Usage

The template takes an unboxed function pointer and wraps it into a functor that takes `RuntimeContext` and `EValues` as inputs and returns void. The wrapped functor will unbox all inputs and forward them to the unboxed kernel.

Here is an example of how to use the template:

```C++
Tensor& my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& out) {
  // ...
  return out;
}

Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op", EXECUTORCH_FN(my_op));
static auto res = register_kernels({my_kernel});
```

Alternatively, you can use the `EXECUTORCH_LIBRARY` macro to simplify the process:

```C++
EXECUTORCH_LIBRARY(my_ns, "my_op", my_op);
```

## Details

The template uses a number of C++17 features to convert each `EValue` to the inferred argument type. It checks whether the first argument is `RuntimeContext` and, if so, removes it. The `call` method of the `WrapUnboxedIntoFunctor` struct then calls the unboxed function with the corresponding arguments.

The `EXECUTORCH_LIBRARY` macro registers the kernel for the operation and stores the result in a static variable, as sketched below.
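
For reference, a minimal sketch of what the `EXECUTORCH_LIBRARY` example above becomes after preprocessing, based on the macro definition in this header (the registered-result variable name follows the `res_##ns` pattern from that definition):

```C++
// EXECUTORCH_LIBRARY(my_ns, "my_op", my_op) expands to roughly:
static auto res_my_ns = register_kernels(
    make_boxed_kernel("my_ns" "::" "my_op", EXECUTORCH_FN(my_op)));
// The adjacent string literals concatenate to "my_ns::my_op".
```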

## Note

The `RuntimeContext` is a placeholder for a context that will be passed to kernels. It is currently empty, but it is planned to be used for kernel temp memory allocation and error handling in the future.

extension/kernel_util/TARGETS

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

extension/kernel_util/make_boxed_from_unboxed_functor.h

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//===----------------------------------------------------------------------===//
/// \file extension/kernel_util/make_boxed_from_unboxed_functor.h
/// Defines a template that can be used to create a boxed version of an unboxed
/// functor.
/// Example usage:
/// ```
/// Tensor&
/// my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor&
/// out) {
///   // ...
///   return out;
/// }
///
/// Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op",
/// EXECUTORCH_FN(my_op));
/// static auto res = register_kernels({my_kernel});
/// ```
/// Or simply:
/// ```
/// EXECUTORCH_LIBRARY(my_ns, "my_op", my_op);
/// ```
///
/// The trick here is to convert each EValue to inferred argument type. This
/// uses a lot of C++17 features.
//===----------------------------------------------------------------------===//

#pragma once
#if __cplusplus < 201703L
#error "This header requires C++17"
#endif

#include <executorch/extension/kernel_util/meta_programming.h>
#include <executorch/extension/kernel_util/type_list.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <cstdlib>
#include <memory>
#include <type_traits>
#include <typeinfo>

namespace torch {
namespace executor {

class KernelRuntimeContext; // Forward declaration
using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove

// evalue_to_arg
template <class T>
struct decay_if_not_tensor final {
  using type = std::decay_t<T>;
};
template <>
struct decay_if_not_tensor<exec_aten::Tensor&> final {
  using type = exec_aten::Tensor&;
};
template <>
struct decay_if_not_tensor<const exec_aten::Tensor&> final {
  using type = const exec_aten::Tensor&;
};

template <class T>
struct evalue_to_arg final {
  static T call(EValue& v) {
    return std::move(v).to<T>();
  }
};

template <>
struct evalue_to_arg<exec_aten::Tensor&> final {
  static exec_aten::Tensor& call(EValue& v) {
    return v.toTensor();
  }
};

template <>
struct evalue_to_arg<const exec_aten::Tensor&> final {
  static const exec_aten::Tensor& call(EValue& v) {
    return v.toTensor();
  }
};
// Call functor with args from stack

template <class Functor, size_t... evalue_arg_indices, typename... ArgTypes>
void call_functor_with_args_from_stack_(
    RuntimeContext& ctx,
    EValue** stack,
    std::index_sequence<evalue_arg_indices...>,
    typelist<ArgTypes...>*) {
  (*Functor::func_ptr())(
      ctx,
      evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call(
          *stack[evalue_arg_indices])...);
}

/**
 * WrapUnboxedIntoFunctor: Given a function pointer, wrap it into a functor that
 * takes EValues as input and returns void. The wrapped functor will unbox all
 * inputs and forward them to unboxed kernel.
 */
template <class FuncType>
struct WrapUnboxedIntoFunctor {
  static_assert(
      is_compile_time_function_pointer<FuncType>::value,
      "Can't handle function other than EXECUTORCH_FN");
  using TrueType = typename FuncType::FuncType;
  using ReturnType = typename infer_function_traits_t<TrueType>::return_type;
  using ArgsType = typename infer_function_traits_t<TrueType>::parameter_types;
  // check if the first argument is RuntimeContext, if so, remove it
  static constexpr bool first_arg_is_context = std::is_same<
      RuntimeContext,
      std::remove_reference_t<head_with_default_t<void, ArgsType>>>::value;
  using ContextRemovedArgsType = std::conditional_t<
      first_arg_is_context,
      drop_if_nonempty_t<ArgsType, 1>,
      ArgsType>;

  static void call(RuntimeContext& ctx, EValue** stack) {
    constexpr size_t num_inputs = size<ContextRemovedArgsType>::value;
    return call_functor_with_args_from_stack_<FuncType>(
        ctx,
        stack,
        std::make_index_sequence<num_inputs>(),
        static_cast<ContextRemovedArgsType*>(nullptr));
  }
};

template <typename FuncType>
static Kernel make_boxed_kernel(const char* name, FuncType) {
  return Kernel(name, WrapUnboxedIntoFunctor<FuncType>::call);
}

#define EXECUTORCH_LIBRARY(ns, op_name, func) \
  static auto res_##ns = register_kernels(    \
      make_boxed_kernel(#ns "::" op_name, EXECUTORCH_FN(func)))
} // namespace executor
} // namespace torch
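
As a quick orientation for readers of this diff, here is a minimal usage sketch built from the declarations above. The op name `my_add_out`, its (elided) body, and the namespace placement are illustrative only; it assumes `Tensor` resolves as in the README example and uses the free `make_boxed_kernel` defined in this header:

```C++
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

namespace torch {
namespace executor {

// Hypothetical unboxed kernel. Because the first parameter is RuntimeContext,
// WrapUnboxedIntoFunctor sets first_arg_is_context and drops it from the
// EValue argument list, so the boxed wrapper expects one EValue per remaining
// parameter (a, b, out) on the stack.
Tensor& my_add_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  (void)ctx;
  // ... element-wise addition writing into `out` would go here ...
  return out;
}

// Build the boxed kernel explicitly and register it...
Kernel my_kernel =
    make_boxed_kernel("my_ns::my_add_out", EXECUTORCH_FN(my_add_out));
static auto res = register_kernels({my_kernel});

// ...or use the macro, which expands to an equivalent registration:
// EXECUTORCH_LIBRARY(my_ns, "my_add_out", my_add_out);

} // namespace executor
} // namespace torch
```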

extension/kernel_util/meta_programming.h

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#if __cplusplus < 201703L
#error "This header requires C++17"
#endif

#include <executorch/extension/kernel_util/type_list.h>
#include <cstdlib>
#include <memory>
#include <type_traits>
#include <typeinfo>

namespace torch {
namespace executor {

// Check if a given type is a function
template <class T>
struct is_function_type : std::false_type {};
template <class Result, class... Args>
struct is_function_type<Result(Args...)> : std::true_type {};
template <class T>
using is_function_type_t = typename is_function_type<T>::type;

// A compile-time wrapper around a function pointer
template <class FuncType_, FuncType_* func_ptr_>
struct CompileTimeFunctionPointer final {
  static_assert(
      is_function_type<FuncType_>::value,
      "EXECUTORCH_FN can only wrap function types.");
  using FuncType = FuncType_;

  static constexpr FuncType* func_ptr() {
    return func_ptr_;
  }
};

// Check if a given type is a compile-time function pointer
template <class T>
struct is_compile_time_function_pointer : std::false_type {};
template <class FuncType, FuncType* func_ptr>
struct is_compile_time_function_pointer<
    CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};

#define EXECUTORCH_FN_TYPE(func)                                       \
  CompileTimeFunctionPointer<                                          \
      std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
      func>
#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()

/**
 * strip_class: helper to remove the class type from pointers to `operator()`.
 */
template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
  using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;

/**
 * Access information about result type or arguments from a function type.
 * Example:
 * using A = function_traits<int (float, double)>::return_type // A == int
 * using A = function_traits<int (float, double)>::parameter_types::tuple_type
 * // A == tuple<float, double>
 */
template <class Func>
struct function_traits {
  static_assert(
      !std::is_same<Func, Func>::value,
      "In function_traits<Func>, Func must be a plain function type.");
};
template <class Result, class... Args>
struct function_traits<Result(Args...)> {
  using func_type = Result(Args...);
  using return_type = Result;
  using parameter_types = typelist<Args...>;
  static constexpr auto number_of_parameters = sizeof...(Args);
};

/**
 * infer_function_traits: creates a `function_traits` type for a simple
 * function (pointer) or functor (lambda/struct). Currently does not support
 * class methods.
 */
template <typename Functor>
struct infer_function_traits {
  using type = function_traits<strip_class_t<decltype(&Functor::operator())>>;
};
template <typename Result, typename... Args>
struct infer_function_traits<Result (*)(Args...)> {
  using type = function_traits<Result(Args...)>;
};
template <typename Result, typename... Args>
struct infer_function_traits<Result(Args...)> {
  using type = function_traits<Result(Args...)>;
};
template <typename T>
using infer_function_traits_t = typename infer_function_traits<T>::type;

} // namespace executor
} // namespace torch
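
To make the traits machinery above concrete, a small sketch (the `scale` function is hypothetical) of what `EXECUTORCH_FN_TYPE` and `infer_function_traits_t` compute at compile time, using only the definitions in this header:

```C++
#include <executorch/extension/kernel_util/meta_programming.h>
#include <type_traits>

namespace torch {
namespace executor {

// Hypothetical free function to inspect.
inline float scale(int x, float factor) {
  return static_cast<float>(x) * factor;
}

// EXECUTORCH_FN_TYPE(scale) names a CompileTimeFunctionPointer type that
// carries &scale as a template argument; EXECUTORCH_FN(scale) is a value of
// that type.
using ScaleFn = EXECUTORCH_FN_TYPE(scale);
static_assert(
    is_compile_time_function_pointer<ScaleFn>::value,
    "EXECUTORCH_FN_TYPE produces a compile-time function pointer");

// infer_function_traits_t recovers the signature from the plain function type.
using Traits = infer_function_traits_t<ScaleFn::FuncType>;
static_assert(std::is_same<Traits::return_type, float>::value, "");
static_assert(Traits::number_of_parameters == 2, "");

} // namespace executor
} // namespace torch
```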

extension/kernel_util/targets.bzl

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    runtime.cxx_library(
        name = "kernel_util",
        srcs = [],
        exported_headers = [
            "make_boxed_from_unboxed_functor.h",
            "meta_programming.h",
            "type_list.h",
        ],
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
        exported_deps = [
            "//executorch/runtime/core:core",
            "//executorch/runtime/core:evalue",
            "//executorch/runtime/kernel:kernel_includes",
            "//executorch/runtime/kernel:kernel_runtime_context",
            "//executorch/runtime/kernel:operator_registry",
        ],
    )

extension/kernel_util/test/TARGETS

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
