[kernel] Add template based unboxing #1284

Closed
wants to merge 11 commits
23 changes: 23 additions & 0 deletions extension/kernel_util/README.md
@@ -0,0 +1,23 @@
The header `make_boxed_from_unboxed_functor.h` defines a template that creates a boxed version of an unboxed functor. It is part of the ExecuTorch `kernel_util` extension and lives in the `torch::executor` namespace.
## Requirements
This header requires C++17 or later.
## Usage
The template takes an unboxed function pointer and wraps it into a functor that takes a `RuntimeContext&` and a stack of `EValue`s as inputs and returns void. The wrapped functor unboxes each input and forwards it to the unboxed kernel.
Here is an example of how to use the template:
```C++
Tensor& my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& out) {
// ...
return out;
}
Kernel my_kernel = make_boxed_kernel("my_ns::my_op", EXECUTORCH_FN(my_op));
static auto res = register_kernels({my_kernel});
```
Alternatively, you can use the `EXECUTORCH_LIBRARY` macro to simplify the process:
```C++
EXECUTORCH_LIBRARY(my_ns, "my_op", my_op);
```
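Under the hood, `EXECUTORCH_LIBRARY(my_ns, "my_op", my_op)` expands to a static call to `register_kernels(make_boxed_kernel("my_ns::my_op", EXECUTORCH_FN(my_op)))`, so the kernel is registered before `main()` runs.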
## Details
The template relies on C++17 metaprogramming to convert each `EValue` on the stack to the inferred argument type of the unboxed kernel. It checks whether the kernel's first parameter is a `RuntimeContext&` and, if so, strips it from the inferred parameter list, since the context is passed in separately. The `call` method of the `WrapUnboxedIntoFunctor` struct then invokes the unboxed function with the converted arguments.
The `EXECUTORCH_LIBRARY` macro registers the kernel for the given operator and stores the result of `register_kernels` in a static variable.
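For the `my_op` example above, the generated wrapper behaves roughly like the hand-written sketch below. This is an illustrative sketch only (the name `boxed_my_op` is hypothetical); the real wrapper is produced at compile time by `WrapUnboxedIntoFunctor` via parameter-pack expansion:
```C++
// Illustrative sketch only: roughly what the generated boxed wrapper does for my_op.
void boxed_my_op(RuntimeContext& ctx, EValue** stack) {
  // Tensor arguments stay references into the EValues on the stack.
  const Tensor& self = stack[0]->toTensor();
  const Tensor& other = stack[1]->toTensor();
  Tensor& out = stack[2]->toTensor();
  // The unboxed kernel's return value is ignored; `out` already lives on the stack.
  my_op(ctx, self, other, out);
}
```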
## Note
The `RuntimeContext` is a placeholder for a context that is passed to kernels. It is currently empty, but is planned to support kernel temporary-memory allocation and error handling in the future.
8 changes: 8 additions & 0 deletions extension/kernel_util/TARGETS
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
145 changes: 145 additions & 0 deletions extension/kernel_util/make_boxed_from_unboxed_functor.h
@@ -0,0 +1,145 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

//===----------------------------------------------------------------------===//
/// \file extension/kernel_util/make_boxed_from_unboxed_functor.h
/// Defines a template that can be used to create a boxed version of an unboxed
/// functor.
/// Example usage:
/// ```
/// Tensor&
/// my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor&
/// out) {
/// // ...
/// return out;
/// }
///
/// Kernel my_kernel = make_boxed_kernel("my_ns::my_op",
/// EXECUTORCH_FN(my_op));
/// static auto res = register_kernels({my_kernel});
/// ```
/// Or simply:
/// ```
/// EXECUTORCH_LIBRARY(my_ns, "my_op", my_op);
/// ```
///
/// The trick here is to convert each EValue to the inferred argument type. This
/// relies heavily on C++17 features.
//===----------------------------------------------------------------------===//

#pragma once
#if __cplusplus < 201703L
#error "This header requires C++17"
#endif

#include <executorch/extension/kernel_util/meta_programming.h>
#include <executorch/extension/kernel_util/type_list.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <cstdlib>
#include <memory>
#include <type_traits>
#include <typeinfo>

namespace torch {
namespace executor {

class KernelRuntimeContext; // Forward declaration
using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove

// decay_if_not_tensor: decay each inferred argument type to its value type,
// except for Tensor references, which keep their reference (and constness) so
// they can bind to the Tensor stored in the EValue on the stack.
template <class T>
struct decay_if_not_tensor final {
using type = std::decay_t<T>;
};
template <>
struct decay_if_not_tensor<exec_aten::Tensor&> final {
using type = exec_aten::Tensor&;
};
template <>
struct decay_if_not_tensor<const exec_aten::Tensor&> final {
using type = const exec_aten::Tensor&;
};

// evalue_to_arg: convert an EValue on the stack into the argument the unboxed
// kernel expects. Tensor arguments are returned by reference; other types are
// moved out of the EValue.
template <class T>
struct evalue_to_arg final {
static T call(EValue& v) {
return std::move(v).to<T>();
}
};

template <>
struct evalue_to_arg<exec_aten::Tensor&> final {
static exec_aten::Tensor& call(EValue& v) {
return v.toTensor();
}
};

template <>
struct evalue_to_arg<const exec_aten::Tensor&> final {
static const exec_aten::Tensor& call(EValue& v) {
return v.toTensor();
}
};
// call_functor_with_args_from_stack_: unbox each EValue on the stack via
// evalue_to_arg and invoke the unboxed functor with the context followed by the
// unboxed arguments.

template <class Functor, size_t... evalue_arg_indices, typename... ArgTypes>
void call_functor_with_args_from_stack_(
RuntimeContext& ctx,
EValue** stack,
std::index_sequence<evalue_arg_indices...>,
typelist<ArgTypes...>*) {
(*Functor::func_ptr())(
ctx,
evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call(
*stack[evalue_arg_indices])...);
}

/**
* WrapUnboxedIntoFunctor: Given a function pointer, wrap it into a functor that
* takes EValues as input and returns void. The wrapped functor will unbox all
* inputs and forward them to unboxed kernel.
*/
template <class FuncType>
struct WrapUnboxedIntoFunctor {
static_assert(
is_compile_time_function_pointer<FuncType>::value,
"Can't handle function other than EXECUTORCH_FN");
using TrueType = typename FuncType::FuncType;
using ReturnType = typename infer_function_traits_t<TrueType>::return_type;
using ArgsType = typename infer_function_traits_t<TrueType>::parameter_types;
// check if the first argument is RuntimeContext, if so, remove it
static constexpr bool first_arg_is_context = std::is_same<
RuntimeContext,
std::remove_reference_t<head_with_default_t<void, ArgsType>>>::value;
using ContextRemovedArgsType = std::conditional_t<
first_arg_is_context,
drop_if_nonempty_t<ArgsType, 1>,
ArgsType>;

static void call(RuntimeContext& ctx, EValue** stack) {
constexpr size_t num_inputs = size<ContextRemovedArgsType>::value;
return call_functor_with_args_from_stack_<FuncType>(
ctx,
stack,
std::make_index_sequence<num_inputs>(),
static_cast<ContextRemovedArgsType*>(nullptr));
}
};

template <typename FuncType>
static Kernel make_boxed_kernel(const char* name, FuncType) {
return Kernel(name, WrapUnboxedIntoFunctor<FuncType>::call);
}

#define EXECUTORCH_LIBRARY(ns, op_name, func) \
static auto res_##ns = register_kernels( \
make_boxed_kernel(#ns "::" op_name, EXECUTORCH_FN(func)))
} // namespace executor
} // namespace torch
115 changes: 115 additions & 0 deletions extension/kernel_util/meta_programming.h
@@ -0,0 +1,115 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once
#if __cplusplus < 201703L
#error "This header requires C++17"
#endif

#include <executorch/extension/kernel_util/type_list.h>
#include <cstdlib>
#include <memory>
#include <type_traits>
#include <typeinfo>

namespace torch {
namespace executor {

// Check if a given type is a function
template <class T>
struct is_function_type : std::false_type {};
template <class Result, class... Args>
struct is_function_type<Result(Args...)> : std::true_type {};
template <class T>
using is_function_type_t = typename is_function_type<T>::type;

// A compile-time wrapper around a function pointer
template <class FuncType_, FuncType_* func_ptr_>
struct CompileTimeFunctionPointer final {
static_assert(
is_function_type<FuncType_>::value,
"EXECUTORCH_FN can only wrap function types.");
using FuncType = FuncType_;

static constexpr FuncType* func_ptr() {
return func_ptr_;
}
};

// Check if a given type is a compile-time function pointer
template <class T>
struct is_compile_time_function_pointer : std::false_type {};
template <class FuncType, FuncType* func_ptr>
struct is_compile_time_function_pointer<
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};

#define EXECUTORCH_FN_TYPE(func) \
CompileTimeFunctionPointer< \
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
func>
#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()
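// For example (illustrative), for a function `Tensor& my_op(RuntimeContext&,
// const Tensor&, const Tensor&, Tensor&)`, EXECUTORCH_FN(my_op) yields a value
// of type CompileTimeFunctionPointer<decltype(my_op), my_op>: the function
// pointer is carried in the type, so func_ptr() is constexpr and needs no
// runtime storage.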

/**
* strip_class: helper to remove the class type from pointers to `operator()`.
*/
template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;

/**
* Access information about result type or arguments from a function type.
* Example:
* using A = function_traits<int (float, double)>::return_type // A == int
* using B = function_traits<int (float, double)>::parameter_types::tuple_type
* // B == tuple<float, double>
*/
template <class Func>
struct function_traits {
static_assert(
!std::is_same<Func, Func>::value,
"In function_traits<Func>, Func must be a plain function type.");
};
template <class Result, class... Args>
struct function_traits<Result(Args...)> {
using func_type = Result(Args...);
using return_type = Result;
using parameter_types = typelist<Args...>;
static constexpr auto number_of_parameters = sizeof...(Args);
};

/**
* infer_function_traits: creates a `function_traits` type for a simple
* function (pointer) or functor (lambda/struct). Currently does not support
* class methods.
*/
template <typename Functor>
struct infer_function_traits {
using type = function_traits<strip_class_t<decltype(&Functor::operator())>>;
};
template <typename Result, typename... Args>
struct infer_function_traits<Result (*)(Args...)> {
using type = function_traits<Result(Args...)>;
};
template <typename Result, typename... Args>
struct infer_function_traits<Result(Args...)> {
using type = function_traits<Result(Args...)>;
};
template <typename T>
using infer_function_traits_t = typename infer_function_traits<T>::type;
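// For example (illustrative), given `Tensor& my_op(RuntimeContext&, const Tensor&,
// const Tensor&, Tensor&)`, infer_function_traits_t<decltype(&my_op)>::return_type
// is Tensor& and ::parameter_types is
// typelist<RuntimeContext&, const Tensor&, const Tensor&, Tensor&>.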

} // namespace executor
} // namespace torch
29 changes: 29 additions & 0 deletions extension/kernel_util/targets.bzl
@@ -0,0 +1,29 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.

The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""

runtime.cxx_library(
name = "kernel_util",
srcs = [],
exported_headers = [
"make_boxed_from_unboxed_functor.h",
"meta_programming.h",
"type_list.h",
],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
"//executorch/runtime/core:core",
"//executorch/runtime/core:evalue",
"//executorch/runtime/kernel:kernel_includes",
"//executorch/runtime/kernel:kernel_runtime_context",
"//executorch/runtime/kernel:operator_registry",
],
)
8 changes: 8 additions & 0 deletions extension/kernel_util/test/TARGETS
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()