Skip to content

Simplify runtime FunctionRef #555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extension/pytree/aten_util/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def define_common_targets():
],
exported_deps = [
"//executorch/extension/pytree:pytree",
"//executorch/runtime/platform:platform",
],
compiler_flags = ["-Wno-missing-prototypes"],
fbcode_deps = [
Expand Down
158 changes: 158 additions & 0 deletions extension/pytree/function_ref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some extension to <functional>.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//
// Extra additions to <functional>
//===----------------------------------------------------------------------===//

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a FunctionRef.

// torch::executor: modified from llvm::function_ref
// see https://www.foonathan.net/2017/01/function-ref-implementation/

#pragma once

#include <cstdint>
#include <type_traits>
#include <utility>

namespace torch {
namespace executor {
namespace pytree {

//===----------------------------------------------------------------------===//
// Features from C++20
//===----------------------------------------------------------------------===//

template <typename T>
struct remove_cvref {
using type =
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
};

template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

template <typename Fn>
class FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
Ret (*callback_)(const void* memory, Params... params) = nullptr;
union Storage {
void* callable;
Ret (*function)(Params...);
} storage_;

public:
FunctionRef() = default;
explicit FunctionRef(std::nullptr_t) {}

/**
* Case 1: A callable object passed by lvalue reference.
* Taking rvalue reference is error prone because the object will be always
* be destroyed immediately.
*/
template <
typename Callable,
// This is not the copy-constructor.
typename std::enable_if<
!std::is_same<remove_cvref_t<Callable>, FunctionRef>::value,
int32_t>::type = 0,
// Avoid lvalue reference to non-capturing lambda.
typename std::enable_if<
!std::is_convertible<Callable, Ret (*)(Params...)>::value,
int32_t>::type = 0,
// Functor must be callable and return a suitable type.
// To make this container type safe, we need to ensure either:
// 1. The return type is void.
// 2. Or the resulting type from calling the callable is convertible to
// the declared return type.
typename std::enable_if<
std::is_void<Ret>::value ||
std::is_convertible<
decltype(std::declval<Callable>()(std::declval<Params>()...)),
Ret>::value,
int32_t>::type = 0>
explicit FunctionRef(Callable& callable)
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
auto& callable = *static_cast<Callable*>(storage.callable);
return static_cast<Ret>(callable(std::forward<Params>(params)...));
}) {
storage_.callable = &callable;
}

/**
* Case 2: A plain function pointer.
* Instead of storing an opaque pointer to underlying callable object,
* store a function pointer directly.
* Note that in the future a variant which coerces compatible function
* pointers could be implemented by erasing the storage type.
*/
/* implicit */ FunctionRef(Ret (*ptr)(Params...))
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
return storage.function(std::forward<Params>(params)...);
}) {
storage_.function = ptr;
}

/**
* Case 3: Implicit conversion from lambda to FunctionRef.
* A common use pattern is like:
* void foo(FunctionRef<...>) {...}
* foo([](...){...})
* Here constructors for non const lvalue reference or function pointer
* would not work because they do not cover implicit conversion from rvalue
* lambda.
* We need to define a constructor for capturing temporary callables and
* always try to convert the lambda to a function pointer behind the scene.
*/
template <
typename Function,
// This is not the copy-constructor.
typename std::enable_if<
!std::is_same<Function, FunctionRef>::value,
int32_t>::type = 0,
// Function is convertible to pointer of (Params...) -> Ret.
typename std::enable_if<
std::is_convertible<Function, Ret (*)(Params...)>::value,
int32_t>::type = 0>
/* implicit */ FunctionRef(const Function& function)
: FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}

Ret operator()(Params... params) const {
return callback_(&storage_, std::forward<Params>(params)...);
}

explicit operator bool() const {
return callback_;
}
};

} // namespace pytree
} // namespace executor
} // namespace torch
3 changes: 2 additions & 1 deletion extension/pytree/pytree.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
#include <memory>
#include <string>

#include <executorch/runtime/core/function_ref.h>
// NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime.
#include <executorch/extension/pytree/function_ref.h>

namespace torch {
namespace executor {
Expand Down
6 changes: 1 addition & 5 deletions extension/pytree/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@ def define_common_targets():
runtime.cxx_library(
name = "pytree",
srcs = [],
exported_headers = ["pytree.h"],
exported_headers = ["pytree.h", "function_ref.h"],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
"//executorch/runtime/platform:platform",
"//executorch/runtime/core:core",
],
)
7 changes: 7 additions & 0 deletions extension/pytree/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ cpp_unittest(
deps = ["//executorch/extension/pytree:pytree"],
)

cpp_unittest(
name = "function_ref_test",
srcs = ["function_ref_test.cpp"],
supports_static_listing = True,
deps = ["//executorch/extension/pytree:pytree"],
)

python_unittest(
name = "test",
srcs = [
Expand Down
90 changes: 90 additions & 0 deletions extension/pytree/test/function_ref_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>

#include <executorch/extension/pytree/function_ref.h>

using namespace ::testing;

namespace torch {
namespace executor {
namespace pytree {

namespace {
class Item {
private:
int32_t val_;
FunctionRef<void(int32_t&)> ref_;

public:
/* implicit */ Item(int32_t val, FunctionRef<void(int32_t&)> ref)
: val_(val), ref_(ref) {}

int32_t get() {
ref_(val_);
return val_;
}
};

void one(int32_t& i) {
i = 1;
}

} // namespace

TEST(FunctionRefTest, CapturingLambda) {
auto one = 1;
auto f = [&](int32_t& i) { i = one; };
Item item(0, FunctionRef<void(int32_t&)>{f});
EXPECT_EQ(item.get(), 1);
// ERROR:
// Item item1(0, f);
// Item item2(0, [&](int32_t& i) { i = 2; });
// FunctionRef<void(int32_t&)> ref([&](int32_t&){});
}

TEST(FunctionRefTest, NonCapturingLambda) {
int32_t val = 0;
FunctionRef<void(int32_t&)> ref([](int32_t& i) { i = 1; });
ref(val);
EXPECT_EQ(val, 1);

val = 0;
auto lambda = [](int32_t& i) { i = 1; };
FunctionRef<void(int32_t&)> ref1(lambda);
ref1(val);
EXPECT_EQ(val, 1);

Item item(0, [](int32_t& i) { i = 1; });
EXPECT_EQ(item.get(), 1);

auto f = [](int32_t& i) { i = 1; };
Item item1(0, f);
EXPECT_EQ(item1.get(), 1);

Item item2(0, std::move(f));
EXPECT_EQ(item2.get(), 1);
}

TEST(FunctionRefTest, FunctionPointer) {
int32_t val = 0;
FunctionRef<void(int32_t&)> ref(one);
ref(val);
EXPECT_EQ(val, 1);

Item item(0, one);
EXPECT_EQ(item.get(), 1);

Item item1(0, &one);
EXPECT_EQ(item1.get(), 1);
}

} // namespace pytree
} // namespace executor
} // namespace torch
51 changes: 5 additions & 46 deletions runtime/core/function_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ class FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
Ret (*callback_)(const void* memory, Params... params) = nullptr;
union Storage {
void* callable;
Ret (*function)(Params...);
} storage_;

Expand All @@ -70,57 +68,18 @@ class FunctionRef<Ret(Params...)> {
explicit FunctionRef(std::nullptr_t) {}

/**
* Case 1: A callable object passed by lvalue reference.
* Taking rvalue reference is error prone because the object will be always
* be destroyed immediately.
*/
template <
typename Callable,
// This is not the copy-constructor.
typename std::enable_if<
!std::is_same<remove_cvref_t<Callable>, FunctionRef>::value,
int32_t>::type = 0,
// Avoid lvalue reference to non-capturing lambda.
typename std::enable_if<
!std::is_convertible<Callable, Ret (*)(Params...)>::value,
int32_t>::type = 0,
// Functor must be callable and return a suitable type.
// To make this container type safe, we need to ensure either:
// 1. The return type is void.
// 2. Or the resulting type from calling the callable is convertible to
// the declared return type.
typename std::enable_if<
std::is_void<Ret>::value ||
std::is_convertible<
decltype(std::declval<Callable>()(std::declval<Params>()...)),
Ret>::value,
int32_t>::type = 0>
explicit FunctionRef(Callable& callable)
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
auto& callable = *static_cast<Callable*>(storage.callable);
return static_cast<Ret>(callable(std::forward<Params>(params)...));
}) {
storage_.callable = &callable;
}

/**
* Case 2: A plain function pointer.
* Case 1: A plain function pointer.
* Instead of storing an opaque pointer to underlying callable object,
* store a function pointer directly.
* Note that in the future a variant which coerces compatible function
* pointers could be implemented by erasing the storage type.
*/
/* implicit */ FunctionRef(Ret (*ptr)(Params...))
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
return storage.function(std::forward<Params>(params)...);
}) {
/* implicit */ FunctionRef(Ret (*ptr)(Params...)) {
storage_.function = ptr;
}

/**
* Case 3: Implicit conversion from lambda to FunctionRef.
* Case 2: Implicit conversion from lambda to FunctionRef.
* A common use pattern is like:
* void foo(FunctionRef<...>) {...}
* foo([](...){...})
Expand All @@ -144,11 +103,11 @@ class FunctionRef<Ret(Params...)> {
: FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}

Ret operator()(Params... params) const {
return callback_(&storage_, std::forward<Params>(params)...);
return storage_.function(std::forward<Params>(params)...);
}

explicit operator bool() const {
return callback_;
return storage_.function;
}
};

Expand Down
Loading