Skip to content

Commit cf57f5f

Browse files
committed
[kernel] Add template based unboxing
Adding a new feature to allow users to bypass codegen and register their kernels directly. This is very useful for custom kernels for custom ops. Example usage: ``` Tensor& my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& out) { // ... return out; } Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op", EXECUTORCH_FN(my_op)); register_kernels({my_kernel}); ``` ghstack-source-id: c430da1 Pull Request resolved: #1284
1 parent 36e03ce commit cf57f5f

File tree

6 files changed

+472
-2
lines changed

6 files changed

+472
-2
lines changed
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
//===----------------------------------------------------------------------===//
10+
/// \file runtime/kernel/make_boxed_from_unboxed_functor.h
11+
/// Defines a template that can be used to create a boxed version of an unboxed
12+
/// functor.
13+
/// Example usage:
14+
/// ```
15+
/// Tensor&
16+
/// my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor&
17+
/// out) {
18+
/// // ...
19+
/// return out;
20+
/// }
21+
///
22+
/// Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op",
23+
/// EXECUTORCH_FN(my_op)); register_kernels({my_kernel});
24+
/// ```
25+
///
26+
/// The trick here is to convert each EValue to the inferred argument type. This
27+
/// uses a lot of C++17 features.
28+
//===----------------------------------------------------------------------===//
29+
30+
#pragma once
31+
#if __cplusplus < 201703L
32+
#error "This header requires C++17"
33+
#endif
34+
35+
#include <executorch/runtime/core/evalue.h>
36+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
37+
#include <executorch/runtime/kernel/type_list.h>
38+
#include <cstdlib>
39+
#include <memory>
40+
#include <type_traits>
41+
#include <typeinfo>
42+
43+
namespace torch {
44+
namespace executor {
45+
46+
class KernelRuntimeContext; // Forward declaration
47+
using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove
48+
49+
// Trait that is true only for plain function types of the form
// `R(Params...)`. Anything else (function pointers, references, functors,
// scalars, ...) falls back to the primary template and yields false.
template <typename Candidate>
struct is_function_type : std::false_type {};

template <typename R, typename... Params>
struct is_function_type<R(Params...)> : std::true_type {};

// Shorthand for the std::true_type / std::false_type selected by the trait.
template <typename Candidate>
using is_function_type_t = typename is_function_type<Candidate>::type;
56+
57+
// A compile-time wrapper around a function pointer
58+
template <class FuncType_, FuncType_* func_ptr_>
59+
struct CompileTimeFunctionPointer final {
60+
static_assert(
61+
is_function_type<FuncType_>::value,
62+
"EXECUTORCH_FN can only wrap function types.");
63+
using FuncType = FuncType_;
64+
65+
static constexpr FuncType* func_ptr() {
66+
return func_ptr_;
67+
}
68+
};
69+
70+
// Check if a given type is a compile-time function pointer
71+
template <class T>
72+
struct is_compile_time_function_pointer : std::false_type {};
73+
template <class FuncType, FuncType* func_ptr>
74+
struct is_compile_time_function_pointer<
75+
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
76+
77+
// EXECUTORCH_FN_TYPE(func): the CompileTimeFunctionPointer type that wraps
// `func`. `decltype(func)` is stripped of reference/pointer so that the first
// template argument is the plain function type.
//
// The wrapper is spelled with its fully qualified name so that the macros can
// be expanded from any namespace, not only from inside torch::executor (or
// after a using-directive for it).
#define EXECUTORCH_FN_TYPE(func)                                      \
  ::torch::executor::CompileTimeFunctionPointer<                      \
      std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
      func>
// EXECUTORCH_FN(func): a default-constructed value of the type above, suitable
// for passing as a tag argument.
#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()
82+
83+
/**
 * strip_class: helper to remove the class type from pointers to `operator()`.
 *
 * Maps a pointer-to-member-function type `Result (Class::*)(Args...)`,
 * optionally `const` and/or `noexcept`, to the plain function type
 * `Result(Args...)`. Since C++17 `noexcept` is part of the function type, so
 * the `noexcept` forms need their own specializations; without them a functor
 * whose call operator is declared `noexcept` would not match anything.
 *
 * The primary template deliberately has no `type` member: using
 * `strip_class_t` with an unsupported type is a compile error.
 */
template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) noexcept> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const noexcept> {
  using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;
98+
99+
/**
100+
* Access information about result type or arguments from a function type.
101+
* Example:
102+
* using A = function_traits<int (float, double)>::return_type // A == int
103+
* using A = function_traits<int (float, double)>::parameter_types::tuple_type
104+
* // A == tuple<float, double>
105+
*/
106+
template <class Func>
107+
struct function_traits {
108+
static_assert(
109+
!std::is_same<Func, Func>::value,
110+
"In function_traits<Func>, Func must be a plain function type.");
111+
};
112+
template <class Result, class... Args>
113+
struct function_traits<Result(Args...)> {
114+
using func_type = Result(Args...);
115+
using return_type = Result;
116+
using parameter_types = typelist<Args...>;
117+
static constexpr auto number_of_parameters = sizeof...(Args);
118+
};
119+
120+
/**
121+
* infer_function_traits: creates a `function_traits` type for a simple
122+
* function (pointer) or functor (lambda/struct). Currently does not support
123+
* class methods.
124+
*/
125+
template <typename Functor>
126+
struct infer_function_traits {
127+
using type = function_traits<strip_class_t<decltype(&Functor::operator())>>;
128+
};
129+
template <typename Result, typename... Args>
130+
struct infer_function_traits<Result (*)(Args...)> {
131+
using type = function_traits<Result(Args...)>;
132+
};
133+
template <typename Result, typename... Args>
134+
struct infer_function_traits<Result(Args...)> {
135+
using type = function_traits<Result(Args...)>;
136+
};
137+
template <typename T>
138+
using infer_function_traits_t = typename infer_function_traits<T>::type;
139+
140+
// evalue_to_arg
141+
template <class T>
142+
struct decay_if_not_tensor final {
143+
using type = std::decay_t<T>;
144+
};
145+
template <>
146+
struct decay_if_not_tensor<exec_aten::Tensor&> final {
147+
using type = exec_aten::Tensor&;
148+
};
149+
template <>
150+
struct decay_if_not_tensor<const exec_aten::Tensor&> final {
151+
using type = const exec_aten::Tensor&;
152+
};
153+
154+
template <class T>
155+
struct evalue_to_arg final {
156+
static T call(EValue& v) {
157+
return std::move(v).to<T>();
158+
}
159+
};
160+
161+
template <>
162+
struct evalue_to_arg<exec_aten::Tensor&> final {
163+
static exec_aten::Tensor& call(EValue& v) {
164+
return v.toTensor();
165+
}
166+
};
167+
168+
template <>
169+
struct evalue_to_arg<const exec_aten::Tensor&> final {
170+
static const exec_aten::Tensor& call(EValue& v) {
171+
return v.toTensor();
172+
}
173+
};
174+
// Call functor with args from stack
175+
176+
template <class Functor, size_t... evalue_arg_indices, typename... ArgTypes>
177+
void call_functor_with_args_from_stack_(
178+
RuntimeContext& ctx,
179+
EValue** stack,
180+
std::index_sequence<evalue_arg_indices...>,
181+
typelist<ArgTypes...>*) {
182+
(*Functor::func_ptr())(
183+
ctx,
184+
evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call(
185+
*stack[evalue_arg_indices])...);
186+
}
187+
188+
/**
189+
* WrapUnboxedIntoFunctor: Given a function pointer, wrap it into a functor that
190+
* takes EValues as input and returns void. The wrapped functor will unbox all
191+
* inputs and forward them to unboxed kernel.
192+
*/
193+
template <class FuncType>
194+
struct WrapUnboxedIntoFunctor {
195+
static_assert(
196+
is_compile_time_function_pointer<FuncType>::value,
197+
"Can't handle function other than EXECUTORCH_FN");
198+
using TrueType = typename FuncType::FuncType;
199+
using ReturnType = typename infer_function_traits_t<TrueType>::return_type;
200+
using ArgsType = typename infer_function_traits_t<TrueType>::parameter_types;
201+
// check if the first argument is RuntimeContext, if so, remove it
202+
static constexpr bool first_arg_is_context = std::is_same<
203+
RuntimeContext,
204+
std::remove_reference_t<head_with_default_t<void, ArgsType>>>::value;
205+
using ContextRemovedArgsType = std::conditional_t<
206+
first_arg_is_context,
207+
drop_if_nonempty_t<ArgsType, 1>,
208+
ArgsType>;
209+
210+
static void call(RuntimeContext& ctx, EValue** stack) {
211+
constexpr size_t num_inputs = size<ContextRemovedArgsType>::value;
212+
return call_functor_with_args_from_stack_<FuncType>(
213+
ctx,
214+
stack,
215+
std::make_index_sequence<num_inputs>(),
216+
static_cast<ContextRemovedArgsType*>(nullptr));
217+
}
218+
};
219+
220+
} // namespace executor
221+
} // namespace torch

runtime/kernel/operator_registry.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
#include <executorch/runtime/core/function_ref.h>
1818
#include <executorch/runtime/platform/compiler.h>
1919
#include <executorch/runtime/platform/platform.h>
20+
#if __cplusplus >= 201703L
21+
#include <executorch/runtime/kernel/make_boxed_from_unboxed_functor.h>
22+
#endif
2023
// Debug switch for operator registry
2124
#if defined(ET_OP_REGISTRY_DEBUG)
2225
#include <ostream>
@@ -200,6 +203,13 @@ struct Kernel {
200203
explicit Kernel(const char* name, KernelKey key, OpFunction func)
201204
: name_(name), kernel_key_(key), op_(func) {}
202205

206+
#if __cplusplus >= 201703L
207+
// Builds a Kernel named `name` whose operator function is
// WrapUnboxedIntoFunctor<FuncType>::call. The second (unnamed) argument is
// never read; it exists only so that FuncType can be deduced from the value
// passed in, e.g. EXECUTORCH_FN(my_op).
template <typename FuncType>
208+
static inline Kernel make_boxed_kernel(const char* name, FuncType) {
209+
return Kernel(name, WrapUnboxedIntoFunctor<FuncType>::call);
210+
}
211+
#endif
212+
203213
Kernel() {}
204214
};
205215

runtime/kernel/targets.bzl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ def define_common_targets():
1111
runtime.cxx_library(
1212
name = "operator_registry",
1313
srcs = ["operator_registry.cpp"],
14-
exported_headers = ["operator_registry.h"],
14+
exported_headers = [
15+
"operator_registry.h",
16+
"make_boxed_from_unboxed_functor.h",
17+
"type_list.h",
18+
],
1519
visibility = [
1620
"//executorch/...",
1721
"@EXECUTORCH_CLIENTS",
@@ -26,7 +30,11 @@ def define_common_targets():
2630
runtime.cxx_library(
2731
name = "operator_registry_TWO_KERNELS_TEST_ONLY",
2832
srcs = ["operator_registry.cpp"],
29-
exported_headers = ["operator_registry.h"],
33+
exported_headers = [
34+
"operator_registry.h",
35+
"make_boxed_from_unboxed_functor.h",
36+
"type_list.h",
37+
],
3038
visibility = [
3139
"//executorch/...",
3240
"@EXECUTORCH_CLIENTS",
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/error.h>
10+
#include <executorch/runtime/core/portable_type/tensor.h>
11+
#include <executorch/runtime/kernel/kernel_runtime_context.h>
12+
#include <executorch/runtime/kernel/operator_registry.h>
13+
#include <executorch/runtime/platform/runtime.h>
14+
#include <gtest/gtest.h>
15+
16+
using namespace ::testing;
17+
using RuntimeContext = torch::executor::KernelRuntimeContext;
18+
using namespace torch::executor;
19+
20+
// Trivial out-variant kernel used for registration tests: ignores the context
// and the input tensor and returns `out` untouched.
Tensor& my_op_out(RuntimeContext& ctx, const Tensor& a, Tensor& out) {
  static_cast<void>(ctx);
  static_cast<void>(a);
  return out;
}
25+
26+
// Kernel with an observable effect: writes 1 into the first int32 element of
// `out` and returns `out`. The context is unused.
Tensor& set_1_out(RuntimeContext& ctx, Tensor& out) {
  static_cast<void>(ctx);
  int32_t* const out_data = out.mutable_data_ptr<int32_t>();
  out_data[0] = 1;
  return out;
}
31+
32+
class MakeBoxedFromUnboxedFunctorTest : public ::testing::Test {
33+
public:
34+
void SetUp() override {
35+
torch::executor::runtime_init();
36+
}
37+
};
38+
39+
// A kernel built with make_boxed_kernel can be registered and is then visible
// through the operator registry.
TEST_F(MakeBoxedFromUnboxedFunctorTest, Basic) {
  Kernel my_kernel =
      Kernel::make_boxed_kernel("my_ns::my_op.out", EXECUTORCH_FN(my_op_out));
  ArrayRef<Kernel> kernels_array = ArrayRef<Kernel>(my_kernel);
  // Check the registration status instead of parking it in an unused local.
  EXPECT_EQ(register_kernels(kernels_array), Error::Ok);
  EXPECT_TRUE(hasOpsFn("my_ns::my_op.out"));
}
47+
48+
// End-to-end check of the unboxing path: register a boxed kernel, fetch it
// back from the registry, call it with an EValue stack and verify that the
// unboxed kernel ran on the tensor that was passed in.
TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxLogicWorks) {
  Kernel my_kernel =
      Kernel::make_boxed_kernel("my_ns::set_1.out", EXECUTORCH_FN(set_1_out));
  ArrayRef<Kernel> kernels_array = ArrayRef<Kernel>(my_kernel);
  // Check the registration status instead of parking it in an unused local.
  EXPECT_EQ(register_kernels(kernels_array), Error::Ok);
  // Stop here if the op is missing: the lookup below relies on it.
  ASSERT_TRUE(hasOpsFn("my_ns::set_1.out"));

  // prepare out tensor
  TensorImpl::SizesType sizes[1] = {5};
  TensorImpl::DimOrderType dim_order[1] = {0};
  int32_t data[5] = {0, 0, 0, 0, 0};
  auto a_impl = TensorImpl(ScalarType::Int, 1, sizes, data, dim_order, nullptr);
  auto a = Tensor(&a_impl);

  // get boxed callable
  auto fn = getOpsFn("my_ns::set_1.out");

  // run it
  RuntimeContext context;
  EValue values[1];
  values[0] = a;
  EValue* stack[1];
  stack[0] = &values[0];

  fn(context, stack);

  // check result
  EXPECT_EQ(a.const_data_ptr<int32_t>()[0], 1);
}

runtime/kernel/test/targets.bzl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,18 @@ def define_common_targets():
3030
],
3131
)
3232

33+
runtime.cxx_test(
34+
name = "make_boxed_from_unboxed_functor_test",
35+
srcs = [
36+
"make_boxed_from_unboxed_functor_test.cpp",
37+
],
38+
deps = [
39+
"//executorch/runtime/kernel:operator_registry",
40+
"//executorch/runtime/kernel:kernel_runtime_context",
41+
"//executorch/runtime/core/exec_aten:lib",
42+
],
43+
)
44+
3345
et_operator_library(
3446
name = "executorch_all_ops",
3547
include_all_operators = True,

0 commit comments

Comments
 (0)