|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===// |
| 10 | +// |
| 11 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 12 | +// See https://llvm.org/LICENSE.txt for license information. |
| 13 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 14 | +// |
| 15 | +//===----------------------------------------------------------------------===// |
| 16 | +// |
| 17 | +// This file contains some extension to <functional>. |
| 18 | +// |
| 19 | +// No library is required when using these functions. |
| 20 | +// |
| 21 | +//===----------------------------------------------------------------------===// |
| 22 | +// Extra additions to <functional> |
| 23 | +//===----------------------------------------------------------------------===// |
| 24 | + |
| 25 | +/// An efficient, type-erasing, non-owning reference to a callable. This is |
| 26 | +/// intended for use as the type of a function parameter that is not used |
| 27 | +/// after the function in question returns. |
| 28 | +/// |
| 29 | +/// This class does not own the callable, so it is not in general safe to store |
| 30 | +/// a FunctionRef. |
| 31 | + |
| 32 | +// torch::executor: modified from llvm::function_ref |
| 33 | +// see https://www.foonathan.net/2017/01/function-ref-implementation/ |
| 34 | + |
| 35 | +#pragma once |
| 36 | + |
| 37 | +#include <cstdint> |
| 38 | +#include <type_traits> |
| 39 | +#include <utility> |
| 40 | + |
| 41 | +namespace torch { |
| 42 | +namespace executor { |
| 43 | +namespace pytree { |
| 44 | + |
| 45 | +//===----------------------------------------------------------------------===// |
| 46 | +// Features from C++20 |
| 47 | +//===----------------------------------------------------------------------===// |
| 48 | + |
| 49 | +template <typename T> |
| 50 | +struct remove_cvref { |
| 51 | + using type = |
| 52 | + typename std::remove_cv<typename std::remove_reference<T>::type>::type; |
| 53 | +}; |
| 54 | + |
| 55 | +template <typename T> |
| 56 | +using remove_cvref_t = typename remove_cvref<T>::type; |
| 57 | + |
| 58 | +template <typename Fn> |
| 59 | +class FunctionRef; |
| 60 | + |
| 61 | +template <typename Ret, typename... Params> |
| 62 | +class FunctionRef<Ret(Params...)> { |
| 63 | + Ret (*callback_)(const void* memory, Params... params) = nullptr; |
| 64 | + union Storage { |
| 65 | + void* callable; |
| 66 | + Ret (*function)(Params...); |
| 67 | + } storage_; |
| 68 | + |
| 69 | + public: |
| 70 | + FunctionRef() = default; |
| 71 | + explicit FunctionRef(std::nullptr_t) {} |
| 72 | + |
| 73 | + /** |
| 74 | + * Case 1: A callable object passed by lvalue reference. |
| 75 | + * Taking rvalue reference is error prone because the object will be always |
| 76 | + * be destroyed immediately. |
| 77 | + */ |
| 78 | + template < |
| 79 | + typename Callable, |
| 80 | + // This is not the copy-constructor. |
| 81 | + typename std::enable_if< |
| 82 | + !std::is_same<remove_cvref_t<Callable>, FunctionRef>::value, |
| 83 | + int32_t>::type = 0, |
| 84 | + // Avoid lvalue reference to non-capturing lambda. |
| 85 | + typename std::enable_if< |
| 86 | + !std::is_convertible<Callable, Ret (*)(Params...)>::value, |
| 87 | + int32_t>::type = 0, |
| 88 | + // Functor must be callable and return a suitable type. |
| 89 | + // To make this container type safe, we need to ensure either: |
| 90 | + // 1. The return type is void. |
| 91 | + // 2. Or the resulting type from calling the callable is convertible to |
| 92 | + // the declared return type. |
| 93 | + typename std::enable_if< |
| 94 | + std::is_void<Ret>::value || |
| 95 | + std::is_convertible< |
| 96 | + decltype(std::declval<Callable>()(std::declval<Params>()...)), |
| 97 | + Ret>::value, |
| 98 | + int32_t>::type = 0> |
| 99 | + explicit FunctionRef(Callable& callable) |
| 100 | + : callback_([](const void* memory, Params... params) { |
| 101 | + auto& storage = *static_cast<const Storage*>(memory); |
| 102 | + auto& callable = *static_cast<Callable*>(storage.callable); |
| 103 | + return static_cast<Ret>(callable(std::forward<Params>(params)...)); |
| 104 | + }) { |
| 105 | + storage_.callable = &callable; |
| 106 | + } |
| 107 | + |
| 108 | + /** |
| 109 | + * Case 2: A plain function pointer. |
| 110 | + * Instead of storing an opaque pointer to underlying callable object, |
| 111 | + * store a function pointer directly. |
| 112 | + * Note that in the future a variant which coerces compatible function |
| 113 | + * pointers could be implemented by erasing the storage type. |
| 114 | + */ |
| 115 | + /* implicit */ FunctionRef(Ret (*ptr)(Params...)) |
| 116 | + : callback_([](const void* memory, Params... params) { |
| 117 | + auto& storage = *static_cast<const Storage*>(memory); |
| 118 | + return storage.function(std::forward<Params>(params)...); |
| 119 | + }) { |
| 120 | + storage_.function = ptr; |
| 121 | + } |
| 122 | + |
| 123 | + /** |
| 124 | + * Case 3: Implicit conversion from lambda to FunctionRef. |
| 125 | + * A common use pattern is like: |
| 126 | + * void foo(FunctionRef<...>) {...} |
| 127 | + * foo([](...){...}) |
| 128 | + * Here constructors for non const lvalue reference or function pointer |
| 129 | + * would not work because they do not cover implicit conversion from rvalue |
| 130 | + * lambda. |
| 131 | + * We need to define a constructor for capturing temporary callables and |
| 132 | + * always try to convert the lambda to a function pointer behind the scene. |
| 133 | + */ |
| 134 | + template < |
| 135 | + typename Function, |
| 136 | + // This is not the copy-constructor. |
| 137 | + typename std::enable_if< |
| 138 | + !std::is_same<Function, FunctionRef>::value, |
| 139 | + int32_t>::type = 0, |
| 140 | + // Function is convertible to pointer of (Params...) -> Ret. |
| 141 | + typename std::enable_if< |
| 142 | + std::is_convertible<Function, Ret (*)(Params...)>::value, |
| 143 | + int32_t>::type = 0> |
| 144 | + /* implicit */ FunctionRef(const Function& function) |
| 145 | + : FunctionRef(static_cast<Ret (*)(Params...)>(function)) {} |
| 146 | + |
| 147 | + Ret operator()(Params... params) const { |
| 148 | + return callback_(&storage_, std::forward<Params>(params)...); |
| 149 | + } |
| 150 | + |
| 151 | + explicit operator bool() const { |
| 152 | + return callback_; |
| 153 | + } |
| 154 | +}; |
| 155 | + |
| 156 | +} // namespace pytree |
| 157 | +} // namespace executor |
| 158 | +} // namespace torch |
0 commit comments