Skip to content

Commit d6b800b

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Add helper function to create empty, full, ones and zeros tensors. (#5261)
Summary: Pull Request resolved: #5261 . Reviewed By: kirklandsign Differential Revision: D62486240 fbshipit-source-id: 1c89db9ed2b31d85ffa68320348f00bc297686f8
1 parent 4da3c5d commit d6b800b

File tree

6 files changed

+513
-5
lines changed

6 files changed

+513
-5
lines changed

extension/tensor/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def define_common_targets():
1515
srcs = [
1616
"tensor_impl_ptr.cpp",
1717
"tensor_ptr.cpp",
18+
"tensor_ptr_maker.cpp",
1819
],
1920
exported_headers = [
2021
"tensor.h",

extension/tensor/tensor_ptr.h

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ inline TensorPtr make_tensor_ptr(
142142
*
143143
* This template overload is specialized for cases where the tensor data is
144144
* provided as a vector. The scalar type is automatically deduced from the
145-
* vector's data type. The deleter ensures that the data vector is properly
146-
* managed and its lifetime is tied to the TensorImpl.
145+
* vector's data type.
147146
*
148147
* @tparam T The C++ type of the tensor elements, deduced from the vector.
149148
* @param sizes A vector specifying the size of each dimension.
@@ -174,8 +173,7 @@ TensorPtr make_tensor_ptr(
174173
*
175174
* This template overload is specialized for cases where the tensor data is
176175
* provided as a vector. The scalar type is automatically deduced from the
177-
* vector's data type. The deleter ensures that the data vector is properly
178-
* managed and its lifetime is tied to the TensorImpl.
176+
* vector's data type.
179177
*
180178
* @tparam T The C++ type of the tensor elements, deduced from the vector.
181179
* @param data A vector containing the tensor's data.
@@ -190,6 +188,27 @@ TensorPtr make_tensor_ptr(
190188
return make_tensor_ptr(make_tensor_impl_ptr(std::move(data), dynamism));
191189
}
192190

191+
/**
192+
* Creates a TensorPtr that manages a Tensor with the specified properties.
193+
*
194+
* This template overload allows creating a Tensor from an initializer list
195+
* of data. The scalar type is automatically deduced from the type of the
196+
* initializer list's elements.
197+
*
198+
* @tparam T The C++ type of the tensor elements, deduced from the initializer
199+
* list.
200+
* @param data An initializer list containing the tensor's data.
201+
* @param dynamism Specifies the mutability of the tensor's shape.
202+
* @return A TensorPtr that manages the newly created TensorImpl.
203+
*/
204+
template <typename T = float>
205+
TensorPtr make_tensor_ptr(
206+
std::initializer_list<T> data,
207+
exec_aten::TensorShapeDynamism dynamism =
208+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
209+
return make_tensor_ptr(std::vector<T>(data), dynamism);
210+
}
211+
193212
/**
194213
* Creates a TensorPtr that manages a Tensor with the specified properties.
195214
*

extension/tensor/tensor_ptr_maker.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/tensor/tensor_ptr_maker.h>
10+
11+
namespace executorch {
12+
namespace extension {
13+
namespace {
14+
template <
15+
typename INT_T,
16+
typename std::enable_if<
17+
std::is_integral<INT_T>::value && !std::is_same<INT_T, bool>::value,
18+
bool>::type = true>
19+
bool extract_scalar(exec_aten::Scalar scalar, INT_T* out_val) {
20+
if (!scalar.isIntegral(/*includeBool=*/false)) {
21+
return false;
22+
}
23+
int64_t val = scalar.to<int64_t>();
24+
if (val < std::numeric_limits<INT_T>::lowest() ||
25+
val > std::numeric_limits<INT_T>::max()) {
26+
return false;
27+
}
28+
*out_val = static_cast<INT_T>(val);
29+
return true;
30+
}
31+
32+
template <
33+
typename FLOAT_T,
34+
typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
35+
type = true>
36+
bool extract_scalar(exec_aten::Scalar scalar, FLOAT_T* out_val) {
37+
double val;
38+
if (scalar.isFloatingPoint()) {
39+
val = scalar.to<double>();
40+
if (std::isfinite(val) &&
41+
(val < std::numeric_limits<FLOAT_T>::lowest() ||
42+
val > std::numeric_limits<FLOAT_T>::max())) {
43+
return false;
44+
}
45+
} else if (scalar.isIntegral(/*includeBool=*/false)) {
46+
val = static_cast<double>(scalar.to<int64_t>());
47+
} else {
48+
return false;
49+
}
50+
*out_val = static_cast<FLOAT_T>(val);
51+
return true;
52+
}
53+
54+
template <
55+
typename BOOL_T,
56+
typename std::enable_if<std::is_same<BOOL_T, bool>::value, bool>::type =
57+
true>
58+
bool extract_scalar(exec_aten::Scalar scalar, BOOL_T* out_val) {
59+
if (scalar.isIntegral(false)) {
60+
*out_val = static_cast<bool>(scalar.to<int64_t>());
61+
return true;
62+
}
63+
if (scalar.isBoolean()) {
64+
*out_val = scalar.to<bool>();
65+
return true;
66+
}
67+
return false;
68+
}
69+
70+
#define ET_EXTRACT_SCALAR(scalar, out_val) \
71+
ET_CHECK_MSG( \
72+
extract_scalar(scalar, &out_val), \
73+
#scalar " could not be extracted: wrong type or out of range");
74+
75+
} // namespace
76+
77+
TensorPtr empty_strided(
78+
std::vector<exec_aten::SizesType> sizes,
79+
std::vector<exec_aten::StridesType> strides,
80+
exec_aten::ScalarType type,
81+
exec_aten::TensorShapeDynamism dynamism) {
82+
std::vector<uint8_t> data(
83+
exec_aten::compute_numel(sizes.data(), sizes.size()) *
84+
exec_aten::elementSize(type));
85+
return make_tensor_ptr(
86+
type,
87+
std::move(sizes),
88+
std::move(data),
89+
{},
90+
std::move(strides),
91+
dynamism);
92+
}
93+
94+
TensorPtr full_strided(
95+
std::vector<exec_aten::SizesType> sizes,
96+
std::vector<exec_aten::StridesType> strides,
97+
exec_aten::Scalar fill_value,
98+
exec_aten::ScalarType type,
99+
exec_aten::TensorShapeDynamism dynamism) {
100+
auto tensor =
101+
empty_strided(std::move(sizes), std::move(strides), type, dynamism);
102+
ET_SWITCH_REALB_TYPES(type, nullptr, "full_strided", CTYPE, [&] {
103+
CTYPE value;
104+
ET_EXTRACT_SCALAR(fill_value, value);
105+
std::fill(
106+
tensor->mutable_data_ptr<CTYPE>(),
107+
tensor->mutable_data_ptr<CTYPE>() + tensor->numel(),
108+
value);
109+
});
110+
return tensor;
111+
}
112+
113+
} // namespace extension
114+
} // namespace executorch

0 commit comments

Comments
 (0)