Skip to content

Commit ad0b92b

Browse files
committed
Refactor make_boxed_from_unboxed_functor.h
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: b1c1598 Pull Request resolved: #2184
1 parent 9e2cf2f commit ad0b92b

File tree

3 files changed

+142
-114
lines changed

3 files changed

+142
-114
lines changed

extension/pybindings/cpp_extension.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
1-
from torch.utils import cpp_extension
21
import os
32

3+
from torch.utils import cpp_extension
4+
45
_HERE = os.path.abspath(__file__)
56
_EXECUTORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
67

7-
def load_inline(name,
8-
cpp_sources,
9-
functions=None,
10-
extra_cflags=None,
11-
extra_ldflags=None,
12-
extra_include_paths=None,
13-
build_directory=None,
14-
verbose=False,
15-
is_python_module=True,
16-
with_pytorch_error_handling=True,
17-
keep_intermediates=True,
18-
use_pch=False):
8+
9+
def load_inline(
10+
name,
11+
cpp_sources,
12+
functions=None,
13+
extra_cflags=None,
14+
extra_ldflags=None,
15+
extra_include_paths=None,
16+
build_directory=None,
17+
verbose=False,
18+
is_python_module=True,
19+
with_pytorch_error_handling=True,
20+
keep_intermediates=True,
21+
use_pch=False,
22+
):
1923
# Register the code into PyTorch
2024
aten_extra_cflags = ["-DUSE_ATEN_LIB"] + (extra_cflags if extra_cflags else [])
21-
extra_ldflags = [f"-L{_EXECUTORCH_PATH}", f"-Wl,-rpath,{_EXECUTORCH_PATH}", "-lexecutorch"] + (extra_ldflags if extra_ldflags else [])
25+
extra_ldflags = [
26+
f"-L{_EXECUTORCH_PATH}",
27+
f"-Wl,-rpath,{_EXECUTORCH_PATH}",
28+
"-lexecutorch",
29+
] + (extra_ldflags if extra_ldflags else [])
2230
module = cpp_extension.load_inline(
2331
name,
2432
cpp_sources,
@@ -37,13 +45,13 @@ def load_inline(name,
3745
cpp_extension.load_inline(
3846
name,
3947
cpp_sources,
40-
functions=None, # leave this out since we are not passing out any python module
48+
functions=None, # leave this out since we are not passing out any python module
4149
extra_cflags=extra_cflags,
4250
extra_ldflags=extra_ldflags,
4351
extra_include_paths=extra_include_paths,
4452
build_directory=build_directory,
4553
verbose=verbose,
46-
is_python_module=False, # don't register as a python module. Load shared library as a side effect.
54+
is_python_module=False, # don't register as a python module. Load shared library as a side effect.
4755
with_pytorch_error_handling=with_pytorch_error_handling,
4856
keep_intermediates=keep_intermediates,
4957
use_pch=use_pch,

runtime/kernel/make_boxed_from_unboxed_functor.h

Lines changed: 4 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
/// return out;
2020
/// }
2121
///
22-
/// Kernel my_kernel = Kernel.make_boxed_kernel("my_ns::my_op",
23-
/// EXECUTORCH_FN(my_op)); register_kernels({my_kernel});
22+
/// Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op",
23+
/// EXECUTORCH_FN(my_op));
24+
/// static auto res = register_kernels({my_kernel});
2425
/// ```
2526
///
2627
/// The trick here is to convert each EValue to inferred argument type. This
@@ -34,109 +35,14 @@
3435

3536
#include <executorch/runtime/core/evalue.h>
3637
#include <executorch/runtime/core/exec_aten/exec_aten.h>
37-
#include <executorch/runtime/kernel/type_list.h>
38-
#include <cstdlib>
39-
#include <memory>
40-
#include <type_traits>
41-
#include <typeinfo>
38+
#include <executorch/runtime/kernel/meta_programming.h>
4239

4340
namespace torch {
4441
namespace executor {
4542

4643
class KernelRuntimeContext; // Forward declaration
4744
using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove
4845

49-
// Check if a given type is a function
50-
template <class T>
51-
struct is_function_type : std::false_type {};
52-
template <class Result, class... Args>
53-
struct is_function_type<Result(Args...)> : std::true_type {};
54-
template <class T>
55-
using is_function_type_t = typename is_function_type<T>::type;
56-
57-
// A compile-time wrapper around a function pointer
58-
template <class FuncType_, FuncType_* func_ptr_>
59-
struct CompileTimeFunctionPointer final {
60-
static_assert(
61-
is_function_type<FuncType_>::value,
62-
"EXECUTORCH_FN can only wrap function types.");
63-
using FuncType = FuncType_;
64-
65-
static constexpr FuncType* func_ptr() {
66-
return func_ptr_;
67-
}
68-
};
69-
70-
// Check if a given type is a compile-time function pointer
71-
template <class T>
72-
struct is_compile_time_function_pointer : std::false_type {};
73-
template <class FuncType, FuncType* func_ptr>
74-
struct is_compile_time_function_pointer<
75-
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
76-
77-
#define EXECUTORCH_FN_TYPE(func) \
78-
CompileTimeFunctionPointer< \
79-
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
80-
func>
81-
#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()
82-
83-
/**
84-
* strip_class: helper to remove the class type from pointers to `operator()`.
85-
*/
86-
template <typename T>
87-
struct strip_class {};
88-
template <typename Class, typename Result, typename... Args>
89-
struct strip_class<Result (Class::*)(Args...)> {
90-
using type = Result(Args...);
91-
};
92-
template <typename Class, typename Result, typename... Args>
93-
struct strip_class<Result (Class::*)(Args...) const> {
94-
using type = Result(Args...);
95-
};
96-
template <typename T>
97-
using strip_class_t = typename strip_class<T>::type;
98-
99-
/**
100-
* Access information about result type or arguments from a function type.
101-
* Example:
102-
* using A = function_traits<int (float, double)>::return_type // A == int
103-
* using A = function_traits<int (float, double)>::parameter_types::tuple_type
104-
* // A == tuple<float, double>
105-
*/
106-
template <class Func>
107-
struct function_traits {
108-
static_assert(
109-
!std::is_same<Func, Func>::value,
110-
"In function_traits<Func>, Func must be a plain function type.");
111-
};
112-
template <class Result, class... Args>
113-
struct function_traits<Result(Args...)> {
114-
using func_type = Result(Args...);
115-
using return_type = Result;
116-
using parameter_types = typelist<Args...>;
117-
static constexpr auto number_of_parameters = sizeof...(Args);
118-
};
119-
120-
/**
121-
* infer_function_traits: creates a `function_traits` type for a simple
122-
* function (pointer) or functor (lambda/struct). Currently does not support
123-
* class methods.
124-
*/
125-
template <typename Functor>
126-
struct infer_function_traits {
127-
using type = function_traits<strip_class_t<decltype(&Functor::operator())>>;
128-
};
129-
template <typename Result, typename... Args>
130-
struct infer_function_traits<Result (*)(Args...)> {
131-
using type = function_traits<Result(Args...)>;
132-
};
133-
template <typename Result, typename... Args>
134-
struct infer_function_traits<Result(Args...)> {
135-
using type = function_traits<Result(Args...)>;
136-
};
137-
template <typename T>
138-
using infer_function_traits_t = typename infer_function_traits<T>::type;
139-
14046
// evalue_to_arg
14147
template <class T>
14248
struct decay_if_not_tensor final {

runtime/kernel/meta_programming.h

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#pragma once
9+
#if __cplusplus < 201703L
10+
#error "This header requires C++17"
11+
#endif
12+
13+
#include <executorch/runtime/kernel/type_list.h>
14+
#include <cstdlib>
15+
#include <memory>
16+
#include <type_traits>
17+
#include <typeinfo>
18+
19+
namespace torch {
20+
namespace executor {
21+
22+
// Check if a given type is a function
23+
template <class T>
24+
struct is_function_type : std::false_type {};
25+
template <class Result, class... Args>
26+
struct is_function_type<Result(Args...)> : std::true_type {};
27+
template <class T>
28+
using is_function_type_t = typename is_function_type<T>::type;
29+
30+
// A compile-time wrapper around a function pointer
31+
template <class FuncType_, FuncType_* func_ptr_>
32+
struct CompileTimeFunctionPointer final {
33+
static_assert(
34+
is_function_type<FuncType_>::value,
35+
"EXECUTORCH_FN can only wrap function types.");
36+
using FuncType = FuncType_;
37+
38+
static constexpr FuncType* func_ptr() {
39+
return func_ptr_;
40+
}
41+
};
42+
43+
// Check if a given type is a compile-time function pointer
44+
template <class T>
45+
struct is_compile_time_function_pointer : std::false_type {};
46+
template <class FuncType, FuncType* func_ptr>
47+
struct is_compile_time_function_pointer<
48+
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
49+
50+
#define EXECUTORCH_FN_TYPE(func) \
51+
CompileTimeFunctionPointer< \
52+
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
53+
func>
54+
#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()
55+
56+
/**
57+
* strip_class: helper to remove the class type from pointers to `operator()`.
58+
*/
59+
template <typename T>
60+
struct strip_class {};
61+
template <typename Class, typename Result, typename... Args>
62+
struct strip_class<Result (Class::*)(Args...)> {
63+
using type = Result(Args...);
64+
};
65+
template <typename Class, typename Result, typename... Args>
66+
struct strip_class<Result (Class::*)(Args...) const> {
67+
using type = Result(Args...);
68+
};
69+
template <typename T>
70+
using strip_class_t = typename strip_class<T>::type;
71+
72+
/**
73+
* Access information about result type or arguments from a function type.
74+
* Example:
75+
* using A = function_traits<int (float, double)>::return_type // A == int
76+
* using A = function_traits<int (float, double)>::parameter_types::tuple_type
77+
* // A == tuple<float, double>
78+
*/
79+
template <class Func>
80+
struct function_traits {
81+
static_assert(
82+
!std::is_same<Func, Func>::value,
83+
"In function_traits<Func>, Func must be a plain function type.");
84+
};
85+
template <class Result, class... Args>
86+
struct function_traits<Result(Args...)> {
87+
using func_type = Result(Args...);
88+
using return_type = Result;
89+
using parameter_types = typelist<Args...>;
90+
static constexpr auto number_of_parameters = sizeof...(Args);
91+
};
92+
93+
/**
94+
* infer_function_traits: creates a `function_traits` type for a simple
95+
* function (pointer) or functor (lambda/struct). Currently does not support
96+
* class methods.
97+
*/
98+
template <typename Functor>
99+
struct infer_function_traits {
100+
using type = function_traits<strip_class_t<decltype(&Functor::operator())>>;
101+
};
102+
template <typename Result, typename... Args>
103+
struct infer_function_traits<Result (*)(Args...)> {
104+
using type = function_traits<Result(Args...)>;
105+
};
106+
template <typename Result, typename... Args>
107+
struct infer_function_traits<Result(Args...)> {
108+
using type = function_traits<Result(Args...)>;
109+
};
110+
template <typename T>
111+
using infer_function_traits_t = typename infer_function_traits<T>::type;
112+
113+
} // namespace executor
114+
} // namespace torch

0 commit comments

Comments
 (0)