Skip to content

Commit 2170a84

Browse files
authored
TensorLayout updates
Differential Revision: D68535453 Pull Request resolved: #7870
1 parent 270271b commit 2170a84

File tree

5 files changed

+148
-45
lines changed

5 files changed

+148
-45
lines changed

runtime/core/targets.bzl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def define_common_targets():
4444
"named_data_map.h",
4545
"result.h",
4646
"span.h",
47-
"tensor_layout.h",
4847
],
4948
visibility = [
5049
"//executorch/...",
@@ -133,3 +132,14 @@ def define_common_targets():
133132
"//executorch/...",
134133
],
135134
)
135+
136+
runtime.cxx_library(
137+
name = "tensor_layout",
138+
srcs = ["tensor_layout.cpp"],
139+
exported_headers = ["tensor_layout.h"],
140+
exported_deps = [
141+
":core",
142+
"//executorch/runtime/core/exec_aten:lib",
143+
],
144+
visibility = ["//executorch/..."],
145+
)

runtime/core/tensor_layout.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
10+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
11+
#include <executorch/runtime/core/span.h>
12+
#include <executorch/runtime/core/tensor_layout.h>
13+
14+
namespace executorch {
15+
namespace runtime {
16+
17+
namespace {
18+
Result<size_t> calculate_nbytes(
19+
const Span<const int32_t>& sizes,
20+
const exec_aten::ScalarType& scalar_type) {
21+
ssize_t n = 1;
22+
for (ssize_t i = 0; i < sizes.size(); i++) {
23+
if (sizes[i] < 0) {
24+
return Error::InvalidArgument;
25+
}
26+
n *= sizes[i];
27+
}
28+
// Use the full namespace to disambiguate from c10::elementSize.
29+
return n * executorch::runtime::elementSize(scalar_type);
30+
}
31+
} // namespace
32+
33+
Result<TensorLayout> TensorLayout::create(
34+
Span<const int32_t> sizes,
35+
Span<const uint8_t> dim_order,
36+
executorch::aten::ScalarType scalar_type) {
37+
auto nbytes = calculate_nbytes(sizes, scalar_type);
38+
if (!nbytes.ok()) {
39+
return nbytes.error();
40+
}
41+
42+
if (dim_order.size() != sizes.size()) {
43+
return Error::InvalidArgument;
44+
}
45+
46+
for (size_t i = 0; i < dim_order.size(); i++) {
47+
if (dim_order[i] >= sizes.size()) {
48+
return Error::InvalidArgument;
49+
}
50+
}
51+
return TensorLayout(sizes, dim_order, scalar_type, nbytes.get());
52+
}
53+
} // namespace runtime
54+
} // namespace executorch

runtime/core/tensor_layout.h

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,55 +10,48 @@
1010

1111
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1212
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
13+
#include <executorch/runtime/core/result.h>
1314
#include <executorch/runtime/core/span.h>
1415

1516
namespace executorch {
1617
namespace runtime {
1718

18-
namespace {
19-
size_t calculate_nbytes(
20-
const Span<const int32_t>& sizes,
21-
const exec_aten::ScalarType& scalar_type) {
22-
ssize_t n = 1;
23-
for (ssize_t i = 0; i < sizes.size(); i++) {
24-
ET_CHECK(sizes[i] >= 0);
25-
n *= sizes[i];
26-
}
27-
// Use the full namespace to disambiguate from c10::elementSize.
28-
return n * executorch::runtime::elementSize(scalar_type);
29-
}
30-
} // namespace
31-
3219
/**
33-
* Metadata describing the layout of external tensors (tensors that are not
34-
stored in the PTE file).
35-
*
36-
* The NamedDataMap used to create the TensorLayout must outlive the
37-
TensorLayout.
20+
* Describes the layout of a tensor.
3821
*/
39-
class TensorLayout {
22+
class ET_EXPERIMENTAL TensorLayout final {
4023
public:
41-
TensorLayout(
42-
executorch::aten::ScalarType scalar_type,
43-
Span<const int32_t> sizes,
44-
Span<const uint8_t> dim_order)
45-
: sizes_(sizes),
46-
dim_order_(dim_order),
47-
scalar_type_(scalar_type),
48-
nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
24+
TensorLayout() = delete;
4925

50-
TensorLayout(const TensorLayout&) = default;
51-
TensorLayout(TensorLayout&&) = default;
52-
TensorLayout& operator=(const TensorLayout&) = default;
53-
TensorLayout& operator=(TensorLayout&& other) = default;
54-
~TensorLayout() = default;
26+
/**
27+
* Creates a TensorLayout from the given parameters.
28+
*
29+
* @param[in] sizes The sizes of the tensor. Note: the span passed here must
30+
* outlive the TensorLayout and all copies of it.
31+
* @param[in] dim_order The dim order of the tensor. Note: the span passed
32+
* here must outlive the TensorLayout and all copies of it.
33+
* @param[in] scalar_type The scalar type of the tensor.
34+
* @return A Result containing the TensorLayout on success, or an error.
35+
*/
36+
static executorch::runtime::Result<TensorLayout> create(
37+
Span<const int32_t> sizes,
38+
Span<const uint8_t> dim_order,
39+
executorch::aten::ScalarType scalar_type);
5540

56-
/// Returns the sizes of the tensor.
41+
/**
42+
* Returns the sizes of the tensor.
43+
*
44+
* NOTE: The TensorLayout must outlive the spans returned here.
45+
*/
5746
Span<const int32_t> sizes() const {
5847
return sizes_;
5948
}
6049

61-
/// Returns the dim order of the tensor.
50+
/**
51+
* Returns the dim order of the tensor.
52+
*
53+
* NOTE: The TensorLayout must outlive the spans returned here.
54+
*/
6255
Span<const uint8_t> dim_order() const {
6356
return dim_order_;
6457
}
@@ -74,6 +67,15 @@ class TensorLayout {
7467
}
7568

7669
private:
70+
TensorLayout(
71+
Span<const int32_t> sizes,
72+
Span<const uint8_t> dim_order,
73+
executorch::aten::ScalarType scalar_type,
74+
size_t nbytes)
75+
: sizes_(sizes),
76+
dim_order_(dim_order),
77+
scalar_type_(scalar_type),
78+
nbytes_(nbytes) {}
7779
/// The sizes of the tensor.
7880
Span<const int32_t> sizes_;
7981

runtime/core/test/targets.bzl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ def define_common_targets():
1919
name = "tensor_layout_test",
2020
srcs = ["tensor_layout_test.cpp"],
2121
deps = [
22-
"//executorch/runtime/core:core",
23-
"//executorch/runtime/core/exec_aten:lib",
22+
"//executorch/runtime/core:tensor_layout",
2423
],
2524
)
2625

runtime/core/test/tensor_layout_test.cpp

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,31 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/runtime/core/error.h>
910
#include <executorch/runtime/core/exec_aten/exec_aten.h>
11+
#include <executorch/runtime/core/result.h>
1012
#include <executorch/runtime/core/tensor_layout.h>
1113

1214
#include <gtest/gtest.h>
1315

1416
using namespace ::testing;
1517
using executorch::aten::ScalarType;
18+
using executorch::runtime::Error;
19+
using executorch::runtime::Result;
1620
using executorch::runtime::Span;
1721
using executorch::runtime::TensorLayout;
1822

1923
TEST(TestTensorLayout, Ctor) {
20-
int32_t sizes[2] = {1, 2};
21-
uint8_t dim_order[2] = {0, 1};
24+
std::array<int32_t, 2> sizes = {1, 2};
25+
std::array<uint8_t, 2> dim_order = {0, 1};
26+
Span<const int32_t> sizes_span = {sizes.data(), sizes.size()};
27+
Span<const uint8_t> dim_order_span = {dim_order.data(), dim_order.size()};
2228

23-
Span<const int32_t> sizes_span = {sizes, sizes + 2};
24-
Span<const uint8_t> dim_order_span = {dim_order, dim_order + 2};
25-
26-
TensorLayout layout =
27-
TensorLayout(ScalarType::Float, sizes_span, dim_order_span);
29+
Result<TensorLayout> layout_res =
30+
TensorLayout::create(sizes_span, dim_order_span, ScalarType::Float);
31+
EXPECT_TRUE(layout_res.ok());
2832

33+
TensorLayout layout = layout_res.get();
2934
EXPECT_EQ(layout.scalar_type(), ScalarType::Float);
3035

3136
EXPECT_EQ(layout.sizes().size(), sizes_span.size());
@@ -38,3 +43,36 @@ TEST(TestTensorLayout, Ctor) {
3843

3944
EXPECT_EQ(layout.nbytes(), 8);
4045
}
46+
47+
TEST(TestTensorLayout, Ctor_InvalidDimOrder) {
48+
std::array<int32_t, 1> sizes = {2};
49+
std::array<uint8_t, 1> dim_order = {1};
50+
Span<const int32_t> sizes_span = {sizes.data(), sizes.size()};
51+
Span<const uint8_t> dim_order_span = {dim_order.data(), dim_order.size()};
52+
53+
Result<TensorLayout> layout_res =
54+
TensorLayout::create(sizes_span, dim_order_span, ScalarType::Float);
55+
EXPECT_EQ(layout_res.error(), Error::InvalidArgument);
56+
}
57+
58+
TEST(TestTensorLayout, Ctor_InvalidSizes) {
59+
std::array<int32_t, 1> sizes = {-1};
60+
std::array<uint8_t, 1> dim_order = {0};
61+
Span<const int32_t> sizes_span = {sizes.data(), sizes.size()};
62+
Span<const uint8_t> dim_order_span = {dim_order.data(), dim_order.size()};
63+
64+
Result<TensorLayout> layout_res =
65+
TensorLayout::create(sizes_span, dim_order_span, ScalarType::Float);
66+
EXPECT_EQ(layout_res.error(), Error::InvalidArgument);
67+
}
68+
69+
TEST(TestTensorLayout, Ctor_SizesDimOrderMismatch) {
70+
std::array<int32_t, 1> sizes = {2};
71+
std::array<uint8_t, 2> dim_order = {0, 1};
72+
Span<const int32_t> sizes_span = {sizes.data(), sizes.size()};
73+
Span<const uint8_t> dim_order_span = {dim_order.data(), dim_order.size()};
74+
75+
Result<TensorLayout> layout_res =
76+
TensorLayout::create(sizes_span, dim_order_span, ScalarType::Float);
77+
EXPECT_EQ(layout_res.error(), Error::InvalidArgument);
78+
}

0 commit comments

Comments
 (0)