Skip to content

Commit cb944b7

Browse files
authored
Expose the compute number of elements helper function.
Differential Revision: D62352386 Pull Request resolved: #5166
1 parent cb71193 commit cb944b7

File tree

4 files changed

+34
-4
lines changed

4 files changed

+34
-4
lines changed

runtime/core/exec_aten/exec_aten.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ using IntArrayRef = at::IntArrayRef;
8787
template <typename T>
8888
using OptionalArrayRef = c10::OptionalArrayRef<T>;
8989

90+
inline ssize_t compute_numel(const SizesType* sizes, ssize_t dim) {
91+
return static_cast<ssize_t>(
92+
c10::multiply_integers(c10::ArrayRef<SizesType>(sizes, dim)));
93+
}
94+
9095
#else // Use executor types
9196

9297
using Tensor = torch::executor::Tensor;
@@ -127,9 +132,12 @@ template <typename T>
127132
using OptionalArrayRef =
128133
torch::executor::optional<torch::executor::ArrayRef<T>>;
129134

135+
using torch::executor::compute_numel;
136+
130137
#endif // Use executor types
131138

132139
} // namespace exec_aten
140+
133141
namespace torch {
134142
namespace executor {
135143
using TensorList = exec_aten::TensorList;

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <type_traits>
2929

3030
#include <executorch/runtime/platform/assert.h>
31+
3132
#ifdef USE_ATEN_LIB
3233
// Note that a lot of the macros/functions defined in this ScalarTypeUtil.h file
3334
// are also defined in c10/core/ScalarType.h, which is included via
@@ -39,14 +40,14 @@
3940
namespace exec_aten {
4041
using ScalarType = at::ScalarType;
4142
}
42-
#else
43+
#else // !USE_ATEN_LIB
4344
#include <executorch/runtime/core/portable_type/scalar_type.h>
4445
#include <executorch/runtime/core/portable_type/string_view.h>
4546
namespace exec_aten {
4647
using ScalarType = torch::executor::ScalarType;
4748
using string_view = torch::executor::string_view;
4849
} // namespace exec_aten
49-
#endif
50+
#endif // USE_ATEN_LIB
5051

5152
namespace executorch {
5253
namespace runtime {
@@ -1361,6 +1362,14 @@ inline exec_aten::ScalarType promoteTypes(
13611362
} // namespace runtime
13621363
} // namespace executorch
13631364

1365+
namespace exec_aten {
1366+
#ifdef USE_ATEN_LIB
1367+
using ::at::elementSize;
1368+
#else // USE_ATEN_LIB
1369+
using ::executorch::runtime::elementSize;
1370+
#endif // USE_ATEN_LIB
1371+
} // namespace exec_aten
1372+
13641373
namespace torch {
13651374
namespace executor {
13661375
// TODO(T197294990): Remove these deprecated aliases once all users have moved

runtime/core/portable_type/tensor_impl.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
namespace torch {
2121
namespace executor {
2222

23-
namespace {
2423
/**
2524
* Compute the number of elements based on the sizes of a tensor.
2625
*/
@@ -39,7 +38,6 @@ ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) {
3938
}
4039
return numel;
4140
}
42-
} // namespace
4341

4442
TensorImpl::TensorImpl(
4543
ScalarType type,

runtime/core/portable_type/tensor_impl.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,5 +253,20 @@ class TensorImpl {
253253
const TensorShapeDynamism shape_dynamism_;
254254
};
255255

256+
/**
257+
* Compute the number of elements based on the sizes of a tensor.
258+
*/
259+
ssize_t compute_numel(
260+
const ::torch::executor::TensorImpl::SizesType* sizes,
261+
ssize_t dim);
262+
256263
} // namespace executor
257264
} // namespace torch
265+
266+
namespace executorch {
267+
namespace runtime {
268+
// TODO(T197294990): Remove these deprecated aliases once all users have moved
269+
// to the new `::executorch` namespaces.
270+
using torch::executor::compute_numel;
271+
} // namespace runtime
272+
} // namespace executorch

0 commit comments

Comments
 (0)