Skip to content

Commit 869704e

Browse files
digantdesaifacebook-github-bot
authored andcommitted
Fork FunctionRef for pytree (#554)
Summary: Pull Request resolved: #554 Now there are two FunctionRef classes, (A) torch::executor::pytree::FunctionRef in executorch/extension/pytree/function_ref.h (B) torch::executor::FunctionRef in executor/runtime/core/function_ref.h Rationale: We are aiming for version (B), compared to (A), to have features going forward, and also planning to eventually delete it completely. Reviewed By: dbort Differential Revision: D49780124 fbshipit-source-id: 48d27757b0e33fd47287fefef0e622595b9e0c09
1 parent deabaca commit 869704e

File tree

6 files changed

+259
-6
lines changed

6 files changed

+259
-6
lines changed

extension/pytree/aten_util/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def define_common_targets():
1717
],
1818
exported_deps = [
1919
"//executorch/extension/pytree:pytree",
20+
"//executorch/runtime/platform:platform",
2021
],
2122
compiler_flags = ["-Wno-missing-prototypes"],
2223
fbcode_deps = [

extension/pytree/function_ref.h

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
10+
//
11+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://llvm.org/LICENSE.txt for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//===----------------------------------------------------------------------===//
16+
//
17+
// This file contains some extension to <functional>.
18+
//
19+
// No library is required when using these functions.
20+
//
21+
//===----------------------------------------------------------------------===//
22+
// Extra additions to <functional>
23+
//===----------------------------------------------------------------------===//
24+
25+
/// An efficient, type-erasing, non-owning reference to a callable. This is
26+
/// intended for use as the type of a function parameter that is not used
27+
/// after the function in question returns.
28+
///
29+
/// This class does not own the callable, so it is not in general safe to store
30+
/// a FunctionRef.
31+
32+
// torch::executor: modified from llvm::function_ref
33+
// see https://www.foonathan.net/2017/01/function-ref-implementation/
34+
35+
#pragma once
36+
37+
#include <cstdint>
38+
#include <type_traits>
39+
#include <utility>
40+
41+
namespace torch {
42+
namespace executor {
43+
namespace pytree {
44+
45+
//===----------------------------------------------------------------------===//
46+
// Features from C++20
47+
//===----------------------------------------------------------------------===//
48+
49+
template <typename T>
50+
struct remove_cvref {
51+
using type =
52+
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
53+
};
54+
55+
template <typename T>
56+
using remove_cvref_t = typename remove_cvref<T>::type;
57+
58+
template <typename Fn>
59+
class FunctionRef;
60+
61+
template <typename Ret, typename... Params>
62+
class FunctionRef<Ret(Params...)> {
63+
Ret (*callback_)(const void* memory, Params... params) = nullptr;
64+
union Storage {
65+
void* callable;
66+
Ret (*function)(Params...);
67+
} storage_;
68+
69+
public:
70+
FunctionRef() = default;
71+
explicit FunctionRef(std::nullptr_t) {}
72+
73+
/**
74+
* Case 1: A callable object passed by lvalue reference.
75+
* Taking rvalue reference is error prone because the object will be always
76+
* be destroyed immediately.
77+
*/
78+
template <
79+
typename Callable,
80+
// This is not the copy-constructor.
81+
typename std::enable_if<
82+
!std::is_same<remove_cvref_t<Callable>, FunctionRef>::value,
83+
int32_t>::type = 0,
84+
// Avoid lvalue reference to non-capturing lambda.
85+
typename std::enable_if<
86+
!std::is_convertible<Callable, Ret (*)(Params...)>::value,
87+
int32_t>::type = 0,
88+
// Functor must be callable and return a suitable type.
89+
// To make this container type safe, we need to ensure either:
90+
// 1. The return type is void.
91+
// 2. Or the resulting type from calling the callable is convertible to
92+
// the declared return type.
93+
typename std::enable_if<
94+
std::is_void<Ret>::value ||
95+
std::is_convertible<
96+
decltype(std::declval<Callable>()(std::declval<Params>()...)),
97+
Ret>::value,
98+
int32_t>::type = 0>
99+
explicit FunctionRef(Callable& callable)
100+
: callback_([](const void* memory, Params... params) {
101+
auto& storage = *static_cast<const Storage*>(memory);
102+
auto& callable = *static_cast<Callable*>(storage.callable);
103+
return static_cast<Ret>(callable(std::forward<Params>(params)...));
104+
}) {
105+
storage_.callable = &callable;
106+
}
107+
108+
/**
109+
* Case 2: A plain function pointer.
110+
* Instead of storing an opaque pointer to underlying callable object,
111+
* store a function pointer directly.
112+
* Note that in the future a variant which coerces compatible function
113+
* pointers could be implemented by erasing the storage type.
114+
*/
115+
/* implicit */ FunctionRef(Ret (*ptr)(Params...))
116+
: callback_([](const void* memory, Params... params) {
117+
auto& storage = *static_cast<const Storage*>(memory);
118+
return storage.function(std::forward<Params>(params)...);
119+
}) {
120+
storage_.function = ptr;
121+
}
122+
123+
/**
124+
* Case 3: Implicit conversion from lambda to FunctionRef.
125+
* A common use pattern is like:
126+
* void foo(FunctionRef<...>) {...}
127+
* foo([](...){...})
128+
* Here constructors for non const lvalue reference or function pointer
129+
* would not work because they do not cover implicit conversion from rvalue
130+
* lambda.
131+
* We need to define a constructor for capturing temporary callables and
132+
* always try to convert the lambda to a function pointer behind the scene.
133+
*/
134+
template <
135+
typename Function,
136+
// This is not the copy-constructor.
137+
typename std::enable_if<
138+
!std::is_same<Function, FunctionRef>::value,
139+
int32_t>::type = 0,
140+
// Function is convertible to pointer of (Params...) -> Ret.
141+
typename std::enable_if<
142+
std::is_convertible<Function, Ret (*)(Params...)>::value,
143+
int32_t>::type = 0>
144+
/* implicit */ FunctionRef(const Function& function)
145+
: FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}
146+
147+
Ret operator()(Params... params) const {
148+
return callback_(&storage_, std::forward<Params>(params)...);
149+
}
150+
151+
explicit operator bool() const {
152+
return callback_;
153+
}
154+
};
155+
156+
} // namespace pytree
157+
} // namespace executor
158+
} // namespace torch

extension/pytree/pytree.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
#include <memory>
1717
#include <string>
1818

19-
#include <executorch/runtime/core/function_ref.h>
19+
// NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime.
20+
#include <executorch/extension/pytree/function_ref.h>
2021

2122
namespace torch {
2223
namespace executor {

extension/pytree/targets.bzl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,9 @@ def define_common_targets():
1010
runtime.cxx_library(
1111
name = "pytree",
1212
srcs = [],
13-
exported_headers = ["pytree.h"],
13+
exported_headers = ["pytree.h", "function_ref.h"],
1414
visibility = [
1515
"//executorch/...",
1616
"@EXECUTORCH_CLIENTS",
1717
],
18-
exported_deps = [
19-
"//executorch/runtime/platform:platform",
20-
"//executorch/runtime/core:core",
21-
],
2218
)

extension/pytree/test/TARGETS

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ cpp_unittest(
1111
deps = ["//executorch/extension/pytree:pytree"],
1212
)
1313

14+
cpp_unittest(
15+
name = "function_ref_test",
16+
srcs = ["function_ref_test.cpp"],
17+
supports_static_listing = True,
18+
deps = ["//executorch/extension/pytree:pytree"],
19+
)
20+
1421
python_unittest(
1522
name = "test",
1623
srcs = [
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <executorch/extension/pytree/function_ref.h>
12+
13+
using namespace ::testing;
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace pytree {
18+
19+
namespace {
20+
class Item {
21+
private:
22+
int32_t val_;
23+
FunctionRef<void(int32_t&)> ref_;
24+
25+
public:
26+
/* implicit */ Item(int32_t val, FunctionRef<void(int32_t&)> ref)
27+
: val_(val), ref_(ref) {}
28+
29+
int32_t get() {
30+
ref_(val_);
31+
return val_;
32+
}
33+
};
34+
35+
void one(int32_t& i) {
36+
i = 1;
37+
}
38+
39+
} // namespace
40+
41+
TEST(FunctionRefTest, CapturingLambda) {
42+
auto one = 1;
43+
auto f = [&](int32_t& i) { i = one; };
44+
Item item(0, FunctionRef<void(int32_t&)>{f});
45+
EXPECT_EQ(item.get(), 1);
46+
// ERROR:
47+
// Item item1(0, f);
48+
// Item item2(0, [&](int32_t& i) { i = 2; });
49+
// FunctionRef<void(int32_t&)> ref([&](int32_t&){});
50+
}
51+
52+
TEST(FunctionRefTest, NonCapturingLambda) {
53+
int32_t val = 0;
54+
FunctionRef<void(int32_t&)> ref([](int32_t& i) { i = 1; });
55+
ref(val);
56+
EXPECT_EQ(val, 1);
57+
58+
val = 0;
59+
auto lambda = [](int32_t& i) { i = 1; };
60+
FunctionRef<void(int32_t&)> ref1(lambda);
61+
ref1(val);
62+
EXPECT_EQ(val, 1);
63+
64+
Item item(0, [](int32_t& i) { i = 1; });
65+
EXPECT_EQ(item.get(), 1);
66+
67+
auto f = [](int32_t& i) { i = 1; };
68+
Item item1(0, f);
69+
EXPECT_EQ(item1.get(), 1);
70+
71+
Item item2(0, std::move(f));
72+
EXPECT_EQ(item2.get(), 1);
73+
}
74+
75+
TEST(FunctionRefTest, FunctionPointer) {
76+
int32_t val = 0;
77+
FunctionRef<void(int32_t&)> ref(one);
78+
ref(val);
79+
EXPECT_EQ(val, 1);
80+
81+
Item item(0, one);
82+
EXPECT_EQ(item.get(), 1);
83+
84+
Item item1(0, &one);
85+
EXPECT_EQ(item1.get(), 1);
86+
}
87+
88+
} // namespace pytree
89+
} // namespace executor
90+
} // namespace torch

0 commit comments

Comments
 (0)