Skip to content

Commit fc42a4e

Browse files
authored
Introduce torch::executor::TensorAccessor
Differential Revision: D66033489 Pull Request resolved: #6905
1 parent 97b58bb commit fc42a4e

File tree

4 files changed

+372
-0
lines changed

4 files changed

+372
-0
lines changed

extension/tensor/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def define_common_targets():
1818
],
1919
exported_headers = [
2020
"tensor.h",
21+
"tensor_accessor.h",
2122
"tensor_ptr.h",
2223
"tensor_ptr_maker.h",
2324
],

extension/tensor/tensor_accessor.h

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/result.h>
13+
14+
namespace executorch {
15+
namespace extension {
16+
namespace internal {
17+
18+
/**
 * Base class template storing the underlying data with size and stride helpers.
 * Inherited by TensorAccessor<> which requires specialization on rank.
 */
template <typename T, ssize_t N>
class TensorAccessorBase {
 public:
  /// Returns the size of the underlying tensor at the given dimension.
  /// Aborts (ET_CHECK_MSG) if `i` is outside [0, dim_ - 1].
  executorch::aten::SizesType size(ssize_t i) const {
    ET_CHECK_MSG(
        i < dim_ && i >= 0,
        "Dimension outside of [0, %zd], got %zd",
        dim_ - 1,
        i);
    return sizes_[i];
  }

  /// Returns the stride of the underlying tensor at the given dimension.
  /// Aborts (ET_CHECK_MSG) if `i` is outside [0, dim_ - 1].
  executorch::aten::StridesType stride(ssize_t i) const {
    ET_CHECK_MSG(
        i < dim_ && i >= 0,
        "Dimension outside of [0, %zd], got %zd",
        dim_ - 1,
        i);
    return strides_[i];
  }

 protected:
  // Protected: only TensorAccessor subclasses may construct this base.
  // All pointers are non-owning; the originating tensor must outlive this.
  TensorAccessorBase(
      T* data,
      const executorch::aten::SizesType* sizes,
      const executorch::aten::StridesType* strides,
      ssize_t dim)
      : data_(data), sizes_(sizes), strides_(strides), dim_(dim) {}

  T* data_; // Non-owning pointer to the tensor's element data.
  const executorch::aten::SizesType* sizes_; // Non-owning sizes array.
  const executorch::aten::StridesType* strides_; // Non-owning strides array.
  ssize_t dim_; // Number of dimensions (rank) viewed by this accessor.
};
58+
59+
} // namespace internal
60+
61+
/**
 * TensorAccessor template with data type and rank as template parameters. No
 * public constructors, can only be created using make_tensor_accessor from a
 * given executorch::aten::Tensor. Use operator[] to index and obtain a lower
 * rank accessor or the underlying scalar value.
 */
template <typename T, ssize_t N>
class TensorAccessor : public internal::TensorAccessorBase<T, N> {
 public:
  /**
   * Index into the outermost dimension.
   *
   * @param i Index.
   * @return If N > 1, a TensorAccessor with N-1 dimensions. If N == 1, a
   * reference to the underlying scalar. Refer to the TensorAccessor<T, 1>
   * specialization.
   */
  TensorAccessor<T, N - 1> operator[](ssize_t i) {
    // Advance by the outermost stride and drop the leading size/stride
    // entries to form the lower-rank view.
    return TensorAccessor<T, N - 1>(
        this->data_ + this->strides_[0] * i,
        this->sizes_ + 1,
        this->strides_ + 1,
        N - 1);
  }

  /**
   * Index into the outermost dimension.
   *
   * @param i Index.
   * @return If N > 1, a constant TensorAccessor with N-1 dimensions. If N == 1,
   * a constant reference to the underlying scalar. Refer to the
   * TensorAccessor<T, 1> specialization.
   */
  const TensorAccessor<T, N - 1> operator[](ssize_t i) const {
    return TensorAccessor<T, N - 1>(
        this->data_ + this->strides_[0] * i,
        this->sizes_ + 1,
        this->strides_ + 1,
        N - 1);
  }

 private:
  // Private: instances come from make_tensor_accessor() or from operator[]
  // of a higher-rank TensorAccessor.
  TensorAccessor(
      T* data,
      const executorch::aten::SizesType* sizes,
      const executorch::aten::StridesType* strides,
      ssize_t dim)
      : internal::TensorAccessorBase<T, N>(data, sizes, strides, dim) {}

  template <typename T2, ssize_t N2>
  friend class TensorAccessor;

  template <typename T2, ssize_t N2>
  friend executorch::runtime::Result<TensorAccessor<T2, N2>>
  make_tensor_accessor(const executorch::aten::Tensor& t);
};
117+
118+
/**
 * TensorAccessor specialization for N == 1, where operator[] returns a
 * reference to the underlying scalar.
 */
template <typename T>
class TensorAccessor<T, 1> : public internal::TensorAccessorBase<T, 1> {
 public:
  /**
   * Index into the outermost dimension.
   *
   * @param i Index.
   * @return Reference to the underlying scalar.
   */
  T& operator[](ssize_t i) {
    return this->data_[this->strides_[0] * i];
  }

  /**
   * Index into the outermost dimension.
   *
   * @param i Index.
   * @return Constant reference to the underlying scalar.
   */
  const T& operator[](ssize_t i) const {
    return this->data_[this->strides_[0] * i];
  }

 private:
  // Private: instances come from make_tensor_accessor() or from operator[]
  // of a rank-2 TensorAccessor.
  TensorAccessor(
      T* data,
      const executorch::aten::SizesType* sizes,
      const executorch::aten::StridesType* strides,
      ssize_t dim)
      : internal::TensorAccessorBase<T, 1>(data, sizes, strides, dim) {}

  template <typename T2, ssize_t N2>
  friend class TensorAccessor;

  template <typename T2, ssize_t N2>
  friend executorch::runtime::Result<TensorAccessor<T2, N2>>
  make_tensor_accessor(const executorch::aten::Tensor& t);
};
160+
161+
/**
162+
* Creates a TensorAccessor<T, N> from the given tensor. The number of dimension
163+
* N and the data type T's size must match those of the input tensor. For
164+
* Executorch tensors, non-trivial dimension order is not supported.
165+
*
166+
* @param tensor Origin tensor. The TensorImpl inside must outlive the returned
167+
* TensorAccessor.
168+
* @return TensorAccessor of the input tensor.
169+
* @retval Error::InvalidArgument Mismatch on data type or number of dimensions.
170+
* @retval Error::NotSupported Input tensor has non-trivial dimension onrder.
171+
*/
172+
template <typename T, ssize_t N>
173+
executorch::runtime::Result<TensorAccessor<T, N>> make_tensor_accessor(
174+
const executorch::aten::Tensor& tensor) {
175+
static_assert(
176+
N > 0,
177+
"TensorAccessor is used for indexing tensors, for scalar use *_data_ptr<T>()");
178+
179+
if (N != tensor.dim()) {
180+
ET_LOG(
181+
Error, "Expecting %zd dimensions but tensor has %zd.", N, tensor.dim());
182+
return executorch::runtime::Error::InvalidArgument;
183+
}
184+
185+
if (sizeof(T) != tensor.element_size()) {
186+
ET_LOG(
187+
Error,
188+
"Size of data type template argument (%zd) not equal to tensor element size (%zd)",
189+
sizeof(T),
190+
tensor.element_size());
191+
return executorch::runtime::Error::InvalidArgument;
192+
}
193+
194+
#ifndef USE_ATEN_LIB
195+
auto dim_order = tensor.dim_order();
196+
for (ssize_t i = 0; i < dim_order.size(); i++) {
197+
if (dim_order[i] != i) {
198+
ET_LOG(Error, "Non-trival dim_order not supported.");
199+
return executorch::runtime::Error::NotSupported;
200+
}
201+
}
202+
#endif
203+
204+
T* ptr = nullptr;
205+
if constexpr (std::is_const_v<T>) {
206+
ptr = tensor.const_data_ptr<T>();
207+
} else {
208+
ptr = tensor.mutable_data_ptr<T>();
209+
}
210+
return TensorAccessor<T, N>(
211+
ptr, tensor.sizes().data(), tensor.strides().data(), N);
212+
}
213+
214+
} // namespace extension
215+
} // namespace executorch

extension/tensor/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def define_common_targets():
1313
runtime.cxx_test(
1414
name = "test" + aten_suffix,
1515
srcs = [
16+
"tensor_accessor_test.cpp",
1617
"tensor_ptr_maker_test.cpp",
1718
"tensor_ptr_test.cpp",
1819
],
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/tensor/tensor_accessor.h>
10+
11+
#include <gtest/gtest.h>
12+
#include <vector>
13+
14+
#include <executorch/extension/tensor/tensor_ptr.h>
15+
#include <executorch/runtime/platform/runtime.h>
16+
17+
using executorch::extension::make_tensor_accessor;
18+
using executorch::extension::make_tensor_ptr;
19+
using executorch::extension::TensorAccessor;
20+
21+
// Test fixture that initializes the ExecuTorch runtime once for the suite.
class TensorAccessorTest : public ::testing::Test {
 protected:
  static void SetUpTestSuite() {
    executorch::runtime::runtime_init();
  }
};
27+
28+
// Builds a 1-D byte tensor holding 0..15 and verifies a rank-1 accessor
// reads back each element.
TEST_F(TensorAccessorTest, From1DTensor) {
  constexpr int32_t kN = 16;
  std::vector<uint8_t> data;
  data.reserve(kN);
  for (int32_t value = 0; value < kN; ++value) {
    data.push_back(static_cast<uint8_t>(value));
  }

  auto tensor =
      make_tensor_ptr({kN}, data.data(), executorch::aten::ScalarType::Byte);
  auto tensor_accessor = make_tensor_accessor<uint8_t, 1>(*tensor.get());
  EXPECT_TRUE(tensor_accessor.ok());
  for (int32_t index = 0; index < kN; ++index) {
    EXPECT_EQ(tensor_accessor.get()[index], index);
  }
}
43+
44+
// Encodes a 4-D position (n, c, h, w) into a single int32_t, one byte per
// coordinate, assuming each coordinate fits in 8 bits.
int32_t
value_at_pos_in_4d_int_tensor(int32_t n, int32_t c, int32_t h, int32_t w) {
  int32_t encoded = n;
  encoded = (encoded << 8) | c;
  encoded = (encoded << 8) | h;
  encoded = (encoded << 8) | w;
  return encoded;
}
49+
50+
// Walks every logical position of an N x C x H x W accessor and checks the
// element equals the position-encoded value from
// value_at_pos_in_4d_int_tensor().
void check_4d_int_tensor_accessor(
    TensorAccessor<int32_t, 4> accessor,
    int32_t N,
    int32_t C,
    int32_t H,
    int32_t W) {
  for (int32_t n = 0; n < N; ++n) {
    for (int32_t c = 0; c < C; ++c) {
      for (int32_t h = 0; h < H; ++h) {
        for (int32_t w = 0; w < W; ++w) {
          const int32_t expected = value_at_pos_in_4d_int_tensor(n, c, h, w);
          EXPECT_EQ(accessor[n][c][h][w], expected);
        }
      }
    }
  }
}
67+
68+
// Builds a contiguous 4-D int tensor whose elements encode their own
// position, then verifies a rank-4 accessor reads every element back.
TEST_F(TensorAccessorTest, From4DTensor) {
  constexpr int32_t kN = 2;
  constexpr int32_t kC = 8;
  constexpr int32_t kH = 4;
  constexpr int32_t kW = 6;
  std::vector<int32_t> data(kN * kC * kH * kW, 0);
  size_t offset = 0;
  for (int32_t n = 0; n < kN; ++n) {
    for (int32_t c = 0; c < kC; ++c) {
      for (int32_t h = 0; h < kH; ++h) {
        for (int32_t w = 0; w < kW; ++w) {
          data[offset] = value_at_pos_in_4d_int_tensor(n, c, h, w);
          ++offset;
        }
      }
    }
  }

  auto tensor = make_tensor_ptr(
      {kN, kC, kH, kW}, data.data(), executorch::aten::ScalarType::Int);
  auto accessor = make_tensor_accessor<int32_t, 4>(*tensor.get());
  EXPECT_TRUE(accessor.ok());
  check_4d_int_tensor_accessor(accessor.get(), kN, kC, kH, kW);
}
91+
92+
#ifdef USE_ATEN_LIB // Non-contiguous tensor is only allowed in ATen mode.
93+
TEST_F(TensorAccessorTest, FromNonContiguousTensor) {
94+
constexpr int32_t kN = 2;
95+
constexpr int32_t kC = 8;
96+
constexpr int32_t kH = 4;
97+
constexpr int32_t kW = 6;
98+
constexpr int32_t kW_padded = 8;
99+
std::vector<int32_t> data(kN * kC * kH * kW_padded, 0);
100+
std::array<executorch::aten::SizesType, 4> sizes = {kN, kC, kH, kW};
101+
std::array<executorch::aten::StridesType, 4> strides = {
102+
kC * kH * kW_padded,
103+
1, // channel last
104+
kC * kW_padded, // width is padded
105+
kC};
106+
107+
size_t idx = 0;
108+
for (int32_t n = 0; n < kN; n++) {
109+
for (int32_t h = 0; h < kH; h++) {
110+
for (int32_t w = 0; w < kW_padded; w++) {
111+
for (int32_t c = 0; c < kC; c++) {
112+
data[idx++] = value_at_pos_in_4d_int_tensor(n, c, h, w);
113+
}
114+
}
115+
}
116+
}
117+
118+
auto tensor = at::from_blob(
119+
data.data(), sizes, strides, at::TensorOptions().dtype(at::kInt));
120+
auto accessor = make_tensor_accessor<int32_t, 4>(tensor);
121+
EXPECT_TRUE(accessor.ok());
122+
check_4d_int_tensor_accessor(accessor.get(), kN, kC, kH, kW);
123+
}
124+
#endif // ifdef USE_ATEN_LIB
125+
126+
// make_tensor_accessor must return an error (not abort) when the requested
// rank or element type does not match the tensor.
TEST_F(TensorAccessorTest, FailOnIncorrectDtypeOrRank) {
  constexpr int32_t kN = 16;
  std::vector<float> data(kN, 0);
  auto tensor = make_tensor_ptr({kN}, data.data());

  // Tensor has rank 1 but creating accessor with rank 2.
  auto fail1 = make_tensor_accessor<float, 2>(*tensor.get());
  EXPECT_FALSE(fail1.ok());

  // Tensor has dtype float but creating accessor with dtype uint8_t.
  auto fail2 = make_tensor_accessor<uint8_t, 1>(*tensor.get());
  EXPECT_FALSE(fail2.ok());
}
139+
140+
#ifndef USE_ATEN_LIB // Dim order is only defined for portable Tensor
// make_tensor_accessor must reject tensors whose dim_order is not the
// identity permutation.
TEST_F(TensorAccessorTest, FailOnNonTrivialDimOrder) {
  constexpr int32_t kN = 8;
  constexpr int32_t kM = 16;
  std::vector<float> data(kN * kM, 0);
  // Column-major layout: dim_order {1, 0} with matching strides {1, kN}.
  auto tensor = make_tensor_ptr(
      {kN, kM},
      data.data(),
      /*dim_order=*/{1, 0},
      /*strides=*/{1, kN});

  // Non trivial dim order is not supported.
  auto fail = make_tensor_accessor<float, 2>(*tensor.get());
  EXPECT_FALSE(fail.ok());
}
#endif // ifndef USE_ATEN_LIB

0 commit comments

Comments
 (0)