|
19 | 19 | /// return out;
|
20 | 20 | /// }
|
21 | 21 | ///
|
22 |
| -/// Kernel my_kernel = Kernel.make_boxed_kernel("my_ns::my_op", |
23 |
| -/// EXECUTORCH_FN(my_op)); register_kernels({my_kernel}); |
| 22 | +/// Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op", |
| 23 | +/// EXECUTORCH_FN(my_op)); |
| 24 | +/// static auto res = register_kernels({my_kernel}); |
24 | 25 | /// ```
|
25 | 26 | ///
|
26 | 27 | /// The trick here is to convert each EValue to inferred argument type. This
|
|
34 | 35 |
|
35 | 36 | #include <executorch/runtime/core/evalue.h>
|
36 | 37 | #include <executorch/runtime/core/exec_aten/exec_aten.h>
|
37 |
| -#include <executorch/runtime/kernel/type_list.h> |
38 |
| -#include <cstdlib> |
39 |
| -#include <memory> |
40 |
| -#include <type_traits> |
41 |
| -#include <typeinfo> |
| 38 | +#include <executorch/runtime/kernel/meta_programming.h> |
42 | 39 |
|
43 | 40 | namespace torch {
|
44 | 41 | namespace executor {
|
45 | 42 |
|
46 | 43 | class KernelRuntimeContext; // Forward declaration
|
47 | 44 | using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove
|
48 | 45 |
|
49 |
| -// Check if a given type is a function |
50 |
| -template <class T> |
51 |
| -struct is_function_type : std::false_type {}; |
52 |
| -template <class Result, class... Args> |
53 |
| -struct is_function_type<Result(Args...)> : std::true_type {}; |
54 |
| -template <class T> |
55 |
| -using is_function_type_t = typename is_function_type<T>::type; |
56 |
| - |
57 |
| -// A compile-time wrapper around a function pointer |
58 |
| -template <class FuncType_, FuncType_* func_ptr_> |
59 |
| -struct CompileTimeFunctionPointer final { |
60 |
| - static_assert( |
61 |
| - is_function_type<FuncType_>::value, |
62 |
| - "EXECUTORCH_FN can only wrap function types."); |
63 |
| - using FuncType = FuncType_; |
64 |
| - |
65 |
| - static constexpr FuncType* func_ptr() { |
66 |
| - return func_ptr_; |
67 |
| - } |
68 |
| -}; |
69 |
| - |
70 |
| -// Check if a given type is a compile-time function pointer |
71 |
| -template <class T> |
72 |
| -struct is_compile_time_function_pointer : std::false_type {}; |
73 |
| -template <class FuncType, FuncType* func_ptr> |
74 |
| -struct is_compile_time_function_pointer< |
75 |
| - CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {}; |
76 |
| - |
77 |
| -#define EXECUTORCH_FN_TYPE(func) \ |
78 |
| - CompileTimeFunctionPointer< \ |
79 |
| - std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \ |
80 |
| - func> |
81 |
| -#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)() |
82 |
| - |
83 |
| -/** |
84 |
| - * strip_class: helper to remove the class type from pointers to `operator()`. |
85 |
| - */ |
86 |
| -template <typename T> |
87 |
| -struct strip_class {}; |
88 |
| -template <typename Class, typename Result, typename... Args> |
89 |
| -struct strip_class<Result (Class::*)(Args...)> { |
90 |
| - using type = Result(Args...); |
91 |
| -}; |
92 |
| -template <typename Class, typename Result, typename... Args> |
93 |
| -struct strip_class<Result (Class::*)(Args...) const> { |
94 |
| - using type = Result(Args...); |
95 |
| -}; |
96 |
| -template <typename T> |
97 |
| -using strip_class_t = typename strip_class<T>::type; |
98 |
| - |
99 |
| -/** |
100 |
| - * Access information about result type or arguments from a function type. |
101 |
| - * Example: |
102 |
| - * using A = function_traits<int (float, double)>::return_type // A == int |
103 |
| - * using A = function_traits<int (float, double)>::parameter_types::tuple_type |
104 |
| - * // A == tuple<float, double> |
105 |
| - */ |
106 |
| -template <class Func> |
107 |
| -struct function_traits { |
108 |
| - static_assert( |
109 |
| - !std::is_same<Func, Func>::value, |
110 |
| - "In function_traits<Func>, Func must be a plain function type."); |
111 |
| -}; |
112 |
| -template <class Result, class... Args> |
113 |
| -struct function_traits<Result(Args...)> { |
114 |
| - using func_type = Result(Args...); |
115 |
| - using return_type = Result; |
116 |
| - using parameter_types = typelist<Args...>; |
117 |
| - static constexpr auto number_of_parameters = sizeof...(Args); |
118 |
| -}; |
119 |
| - |
120 |
| -/** |
121 |
| - * infer_function_traits: creates a `function_traits` type for a simple |
122 |
| - * function (pointer) or functor (lambda/struct). Currently does not support |
123 |
| - * class methods. |
124 |
| - */ |
125 |
| -template <typename Functor> |
126 |
| -struct infer_function_traits { |
127 |
| - using type = function_traits<strip_class_t<decltype(&Functor::operator())>>; |
128 |
| -}; |
129 |
| -template <typename Result, typename... Args> |
130 |
| -struct infer_function_traits<Result (*)(Args...)> { |
131 |
| - using type = function_traits<Result(Args...)>; |
132 |
| -}; |
133 |
| -template <typename Result, typename... Args> |
134 |
| -struct infer_function_traits<Result(Args...)> { |
135 |
| - using type = function_traits<Result(Args...)>; |
136 |
| -}; |
137 |
| -template <typename T> |
138 |
| -using infer_function_traits_t = typename infer_function_traits<T>::type; |
139 |
| - |
140 | 46 | // evalue_to_arg
|
141 | 47 | template <class T>
|
142 | 48 | struct decay_if_not_tensor final {
|
|
0 commit comments