Skip to content

Commit 6de2ca2

Browse files
sxu authored and facebook-github-bot committed
Introduce torch::executor::TensorAccessor (#6905)
Summary: Replicate the TensorAccessor template from https://github.com/pytorch/pytorch/blob/fc813df1200b530d246eacc710781241c5a9dedf/aten/src/ATen/core/TensorAccessor.h#L73. Differential Revision: D66033489
1 parent 1de96f8 commit 6de2ca2

File tree

4 files changed

+329
-0
lines changed

4 files changed

+329
-0
lines changed

extension/tensor/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def define_common_targets():
1818
],
1919
exported_headers = [
2020
"tensor.h",
21+
"tensor_accessor.h",
2122
"tensor_ptr.h",
2223
"tensor_ptr_maker.h",
2324
],

extension/tensor/tensor_accessor.h

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/result.h>
13+
14+
namespace executorch {
15+
namespace extension {
16+
17+
/**
18+
* Base class template storing the underlying data with size and stride helpers.
19+
* Inherited by TensorAccessor<> which requires specialization on rank.
20+
*/
21+
template <typename T, ssize_t N>
22+
class TensorAccessorBase {
23+
public:
24+
/// Reteurns the size of the underlying tensor at the given dimension.
25+
executorch::aten::SizesType size(ssize_t i) const {
26+
ET_CHECK_MSG(
27+
i < dim_ && i >= 0,
28+
"Dimension outside of [0, %zd], got %zd",
29+
dim_ - 1,
30+
i);
31+
return sizes_[i];
32+
}
33+
34+
/// Reteurns the stride of the underlying tensor at the given dimension.
35+
executorch::aten::StridesType stride(ssize_t i) const {
36+
ET_CHECK_MSG(
37+
i < dim_ && i >= 0,
38+
"Dimension outside of [0, %zd], got %zd",
39+
dim_ - 1,
40+
i);
41+
return strides_[i];
42+
}
43+
44+
protected:
45+
TensorAccessorBase(
46+
T* data,
47+
const executorch::aten::SizesType* sizes,
48+
const executorch::aten::StridesType* strides,
49+
ssize_t dim)
50+
: data_(data), sizes_(sizes), strides_(strides), dim_(dim) {}
51+
52+
T* data_;
53+
const executorch::aten::SizesType* sizes_;
54+
const executorch::aten::StridesType* strides_;
55+
ssize_t dim_;
56+
};
57+
58+
// Forward declarations.
59+
60+
template <typename T, ssize_t N>
61+
class TensorAccessor;
62+
63+
template <typename T, ssize_t N>
64+
executorch::runtime::Result<TensorAccessor<T, N>> make_tensor_accessor(
65+
const executorch::aten::Tensor& t);
66+
67+
/**
68+
* TensorAccessor template with data type and rank as template parameters. No
69+
* public constructors, can only be created using make_tensor_accessor from a
70+
* given executorch::aten::Tensor. Use operator[] to index and obtain a lower
71+
* rank accessor or the underlying scalar value.
72+
*/
73+
template <typename T, ssize_t N>
74+
class TensorAccessor : public TensorAccessorBase<T, N> {
75+
public:
76+
/**
77+
* Index into the the outer most dimension.
78+
*
79+
* @param i Index.
80+
* @return If N > 1, a TensorAccessor with N-1 dimensions. If N == 1, a
81+
* reference to the underlying scalar. Refer to the TensorAccessor<T, 1>
82+
* specialization.
83+
*/
84+
TensorAccessor<T, N - 1> operator[](ssize_t i) {
85+
return TensorAccessor<T, N - 1>(
86+
this->data_ + this->strides_[0] * i,
87+
this->sizes_ + 1,
88+
this->strides_ + 1,
89+
N - 1);
90+
}
91+
92+
/**
93+
* Index into the the outer most dimension.
94+
*
95+
* @param i Index.
96+
* @return If N > 1, a constant TensorAccessor with N-1 dimensions. If N == 1,
97+
* a constant reference to the underlying scalar. Refer to the
98+
* TensorAccessor<T, 1> specialization.
99+
*/
100+
const TensorAccessor<T, N - 1> operator[](ssize_t i) const {
101+
return TensorAccessor<T, N - 1>(
102+
this->data_ + this->strides_[0] * i,
103+
this->sizes_ + 1,
104+
this->strides_ + 1,
105+
N - 1);
106+
}
107+
108+
private:
109+
TensorAccessor(
110+
T* data,
111+
const executorch::aten::SizesType* sizes,
112+
const executorch::aten::StridesType* strides,
113+
ssize_t dim)
114+
: TensorAccessorBase<T, N>(data, sizes, strides, dim) {}
115+
116+
template <typename T2, ssize_t N2>
117+
friend class TensorAccessor;
118+
119+
template <typename T2, ssize_t N2>
120+
friend executorch::runtime::Result<TensorAccessor<T2, N2>>
121+
make_tensor_accessor(const executorch::aten::Tensor& t);
122+
};
123+
124+
/**
125+
* TensorAccessor specialization for N == 1, where operator[] returns a
126+
* reference to the underlying scalar.
127+
*/
128+
template <typename T>
129+
class TensorAccessor<T, 1> : public TensorAccessorBase<T, 1> {
130+
public:
131+
/**
132+
* Index into the the outer most dimension.
133+
*
134+
* @param i Index.
135+
* @return Reference to the underlying scalar.
136+
*/
137+
T& operator[](ssize_t i) {
138+
return this->data_[this->strides_[0] * i];
139+
}
140+
141+
/**
142+
* Index into the the outer most dimension.
143+
*
144+
* @param i Index.
145+
* @return Constant reference to the underlying scalar.
146+
*/
147+
const T& operator[](ssize_t i) const {
148+
return this->data_[this->strides_[0] * i];
149+
}
150+
151+
private:
152+
TensorAccessor(
153+
T* data,
154+
const executorch::aten::SizesType* sizes,
155+
const executorch::aten::StridesType* strides,
156+
ssize_t dim)
157+
: TensorAccessorBase<T, 1>(data, sizes, strides, dim) {}
158+
159+
template <typename T2, ssize_t N2>
160+
friend class TensorAccessor;
161+
162+
template <typename T2, ssize_t N2>
163+
friend executorch::runtime::Result<TensorAccessor<T2, N2>>
164+
make_tensor_accessor(const executorch::aten::Tensor& t);
165+
};
166+
167+
/**
168+
* Creates a TensorAccessor<T, N> from the given tensor. The number of dimension
169+
* N must match the input tensor. For Executorch tensors, does not support
170+
* non-trivial dimension order.
171+
*
172+
* @params tensor Origin tensor.
173+
* @return TensorAccessor of the input tensor.
174+
* @retval Error::InvalidArgument Mismatch on data type or number of dimensions.
175+
* @retval Error::NotSupported Input tensor has non-trivial dimension onrder.
176+
*/
177+
template <typename T, ssize_t N>
178+
executorch::runtime::Result<TensorAccessor<T, N>> make_tensor_accessor(
179+
const executorch::aten::Tensor& tensor) {
180+
static_assert(
181+
N > 0,
182+
"TensorAccessor is used for indexing tensors, for scalar use *_data_ptr<T>()");
183+
184+
if (N != tensor.dim()) {
185+
ET_LOG(
186+
Error, "Expecting %zd dimensions but tensor has %zd.", N, tensor.dim());
187+
return executorch::runtime::Error::InvalidArgument;
188+
}
189+
190+
if (sizeof(T) != tensor.element_size()) {
191+
ET_LOG(
192+
Error,
193+
"Size of data type template argument (%zd) not equal to tensor element size (%zd)",
194+
sizeof(T),
195+
tensor.element_size());
196+
return executorch::runtime::Error::InvalidArgument;
197+
}
198+
199+
#ifndef USE_ATEN_LIB
200+
auto dim_order = tensor.dim_order();
201+
for (ssize_t i = 0; i < dim_order.size(); i++) {
202+
if (dim_order[i] != i) {
203+
ET_LOG(Error, "Non-trival dim_order not supported.");
204+
return executorch::runtime::Error::NotSupported;
205+
}
206+
}
207+
#endif
208+
209+
T* ptr = nullptr;
210+
if constexpr (std::is_const_v<T>) {
211+
ptr = tensor.const_data_ptr<T>();
212+
} else {
213+
ptr = tensor.mutable_data_ptr<T>();
214+
}
215+
return TensorAccessor<T, N>(
216+
ptr, tensor.sizes().data(), tensor.strides().data(), N);
217+
}
218+
219+
} // namespace extension
220+
} // namespace executorch

extension/tensor/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def define_common_targets():
1313
runtime.cxx_test(
1414
name = "test" + aten_suffix,
1515
srcs = [
16+
"tensor_accessor_test.cpp",
1617
"tensor_ptr_maker_test.cpp",
1718
"tensor_ptr_test.cpp",
1819
],
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <executorch/extension/tensor/tensor_accessor.h>
12+
#include <executorch/extension/tensor/tensor_ptr.h>
13+
#include <executorch/runtime/platform/runtime.h>
14+
15+
using namespace ::executorch::extension;
16+
using namespace ::executorch::runtime;
17+
18+
// Test fixture for TensorAccessor tests; initializes the ExecuTorch runtime
// once for the whole suite before any test in it runs.
class TensorAccessorTest : public ::testing::Test {
 protected:
  static void SetUpTestSuite() {
    runtime_init();
  }
};
24+
25+
// Verifies that a rank-1 accessor reads back every element of a 1-D tensor.
TEST_F(TensorAccessorTest, From1DTensor) {
  constexpr int32_t kN = 16;
  std::vector<uint8_t> buffer(kN, 0);
  uint8_t next_value = 0;
  for (auto& element : buffer) {
    element = next_value++;
  }

  auto tensor =
      make_tensor_ptr({kN}, buffer.data(), executorch::aten::ScalarType::Byte);
  auto accessor = make_tensor_accessor<uint8_t, 1>(*tensor.get());
  EXPECT_TRUE(accessor.ok());
  for (int32_t index = 0; index < kN; index++) {
    EXPECT_EQ(accessor.get()[index], index);
  }
}
40+
41+
// Verifies nested indexing through a rank-4 accessor down to scalars.
TEST_F(TensorAccessorTest, From4DTensor) {
  constexpr int32_t kN = 2;
  constexpr int32_t kC = 8;
  constexpr int32_t kH = 4;
  constexpr int32_t kW = 6;
  std::vector<int32_t> data(kN * kC * kH * kW, 0);
  // Pack each coordinate into its own byte so every element gets a unique,
  // position-dependent value. Note: the original used `&` here, which
  // collapses almost every expected value to 0 and made the comparisons
  // vacuous; `|` is the intended composition.
  auto value_at = [](int32_t n, int32_t c, int32_t h, int32_t w) {
    return (n << 24) | (c << 16) | (h << 8) | w;
  };

  size_t idx = 0;
  for (int32_t n = 0; n < kN; n++) {
    for (int32_t c = 0; c < kC; c++) {
      for (int32_t h = 0; h < kH; h++) {
        for (int32_t w = 0; w < kW; w++) {
          data[idx++] = value_at(n, c, h, w);
        }
      }
    }
  }

  auto tensor = make_tensor_ptr(
      {kN, kC, kH, kW}, data.data(), executorch::aten::ScalarType::Int);
  auto tensor_accessor_4d = make_tensor_accessor<int32_t, 4>(*tensor.get());
  EXPECT_TRUE(tensor_accessor_4d.ok());
  // Peel one dimension per loop level and compare every scalar.
  for (int32_t n = 0; n < kN; n++) {
    auto tensor_accessor_3d = tensor_accessor_4d.get()[n];
    for (int32_t c = 0; c < kC; c++) {
      auto tensor_accessor_2d = tensor_accessor_3d[c];
      for (int32_t h = 0; h < kH; h++) {
        auto tensor_accessor_1d = tensor_accessor_2d[h];
        for (int32_t w = 0; w < kW; w++) {
          EXPECT_EQ(tensor_accessor_1d[w], value_at(n, c, h, w));
        }
      }
    }
  }
}
79+
80+
// Verifies that make_tensor_accessor rejects rank and dtype mismatches.
TEST_F(TensorAccessorTest, FailOnIncorrectDtypeOrRank) {
  constexpr int32_t kN = 16;
  std::vector<float> values(kN, 0);
  auto tensor = make_tensor_ptr({kN}, values.data());

  // Rank mismatch: the tensor is 1-D but a rank-2 accessor is requested.
  auto wrong_rank = make_tensor_accessor<float, 2>(*tensor.get());
  EXPECT_FALSE(wrong_rank.ok());

  // Dtype mismatch: the tensor holds float but uint8_t is requested.
  auto wrong_dtype = make_tensor_accessor<uint8_t, 1>(*tensor.get());
  EXPECT_FALSE(wrong_dtype.ok());
}
90+
91+
// Verifies that make_tensor_accessor rejects tensors whose dim_order permutes
// the physical layout (ExecuTorch mode only).
TEST_F(TensorAccessorTest, FailOnNonTrivialDimOrder) {
#ifndef USE_ATEN_LIB
  constexpr int32_t kN = 16;
  constexpr int32_t kM = 16;
  std::vector<float> storage(kN * kM, 0);
  // Column-major layout: dim_order {1, 0} with matching strides {1, kN}.
  auto tensor = make_tensor_ptr(
      {kN, kM},
      storage.data(),
      // dim_order
      {1, 0},
      // strides
      {1, kN});

  auto accessor = make_tensor_accessor<float, 2>(*tensor.get());
  EXPECT_FALSE(accessor.ok());
#endif
}

0 commit comments

Comments
 (0)