Skip to content

Commit f326ee1

Browse files
authored
Adopt the new tensor API for aten_util.
Differential Revision: D62168422 Pull Request resolved: #5062
1 parent ae05ed8 commit f326ee1

File tree

6 files changed

+20
-47
lines changed

6 files changed

+20
-47
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ cmake_dependent_option(
228228
)
229229

230230
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
231+
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
231232
set(EXECUTORCH_BUILD_KERNELS_CUSTOM ON)
232233
endif()
233234

extension/aten_util/make_aten_functor_from_et_functor.h

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#endif
2121
#include <ATen/native/Resize.h>
2222
#include <executorch/extension/kernel_util/type_list.h>
23+
#include <executorch/extension/tensor/tensor.h>
2324
#include <executorch/runtime/core/evalue.h>
24-
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
2525
#include <torch/torch.h>
2626

2727
namespace executorch {
@@ -105,48 +105,20 @@ struct type_convert<
105105
typename remove_const_ref<ETensor>::type,
106106
torch::executor::Tensor>>>
107107
final {
108-
explicit type_convert(ATensor value) : value_(value) {
109-
auto sizes =
110-
std::make_shared<std::vector<torch::executor::Tensor::SizesType>>(
111-
value_.sizes().begin(), value_.sizes().end());
112-
const ssize_t dim = sizes->size();
113-
auto dim_order =
114-
std::make_shared<std::vector<torch::executor::Tensor::DimOrderType>>(
115-
dim);
116-
auto strides =
117-
std::make_shared<std::vector<torch::executor::Tensor::StridesType>>(
118-
dim);
119-
120-
std::iota(dim_order->begin(), dim_order->end(), 0);
121-
::executorch::runtime::dim_order_to_stride_nocheck(
122-
sizes->data(), dim_order->data(), dim, strides->data());
123-
124-
auto tensor_impl = std::make_shared<torch::executor::TensorImpl>(
125-
static_cast<torch::executor::ScalarType>(value_.scalar_type()),
126-
sizes->size(),
127-
sizes->data(),
128-
value_.mutable_data_ptr(),
129-
dim_order->data(),
130-
strides->data());
131-
132-
converted_ = std::unique_ptr<
133-
torch::executor::Tensor,
134-
std::function<void(torch::executor::Tensor*)>>(
135-
new torch::executor::Tensor(tensor_impl.get()),
136-
[sizes, dim_order, strides, tensor_impl](
137-
torch::executor::Tensor* pointer) { delete pointer; });
138-
}
108+
explicit type_convert(ATensor value)
109+
: value_(value),
110+
converted_(from_blob(
111+
value_.mutable_data_ptr(),
112+
{value_.sizes().begin(), value_.sizes().end()},
113+
::torch::executor::ScalarType(value_.scalar_type()))) {}
139114

140115
ETensor call() {
141116
return *converted_;
142117
}
143118

144119
private:
145120
ATensor value_;
146-
std::unique_ptr<
147-
torch::executor::Tensor,
148-
std::function<void(torch::executor::Tensor*)>>
149-
converted_;
121+
TensorPtr converted_;
150122
};
151123

152124
// Tensors: ETen to ATen.
@@ -158,15 +130,14 @@ struct type_convert<
158130
std::is_same_v<typename remove_const_ref<ATensor>::type, at::Tensor> &&
159131
std::is_same_v<
160132
typename remove_const_ref<ETensor>::type,
161-
torch::executor::Tensor>>>
133+
::torch::executor::Tensor>>>
162134
final {
163135
explicit type_convert(ETensor value)
164-
: value_(value), sizes_(value_.sizes().begin(), value_.sizes().end()) {
165-
converted_ = at::from_blob(
166-
value_.mutable_data_ptr(),
167-
sizes_,
168-
static_cast<c10::ScalarType>(value_.scalar_type()));
169-
}
136+
: value_(value),
137+
converted_(at::from_blob(
138+
value_.mutable_data_ptr(),
139+
std::vector<int64_t>{value_.sizes().begin(), value_.sizes().end()},
140+
c10::ScalarType(value_.scalar_type()))) {}
170141

171142
ATensor call() {
172143
return converted_;
@@ -175,7 +146,6 @@ struct type_convert<
175146
private:
176147
ETensor value_;
177148
at::Tensor converted_;
178-
std::vector<int64_t> sizes_;
179149
};
180150

181151
// Optionals: ATen to ETen.

extension/aten_util/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def define_common_targets():
2727
],
2828
exported_deps = [
2929
"//executorch/extension/kernel_util:kernel_util",
30+
"//executorch/extension/tensor:tensor",
3031
"//executorch/runtime/core:core",
3132
"//executorch/runtime/core:evalue",
3233
"//executorch/runtime/core/exec_aten:lib",

extension/llm/custom_ops/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
9494
endif()
9595

9696
target_link_libraries(
97-
custom_ops_aot_lib PUBLIC cpublas torch extension_threadpool
97+
custom_ops_aot_lib PUBLIC cpublas torch extension_tensor
98+
extension_threadpool
9899
)
99100
if(WIN32)
100101
# There is no direct replacement for libpthread.so on Windows. For the

extension/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ endif()
1818

1919
list(TRANSFORM _extension_tensor__srcs PREPEND "${EXECUTORCH_ROOT}/")
2020
add_library(extension_tensor ${_extension_tensor__srcs})
21-
target_link_libraries(extension_tensor executorch)
21+
target_link_libraries(extension_tensor executorch_no_prim_ops)
2222
target_include_directories(extension_tensor PUBLIC ${EXECUTORCH_ROOT}/..)
2323
target_compile_options(extension_tensor PUBLIC ${_common_compile_options})
2424

extension/tensor/tensor_impl_ptr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ TensorImplPtr make_tensor_impl_ptr(
9191
tensor_impl.release(),
9292
TensorImplPtrDeleter{
9393
std::unique_ptr<void, std::function<void(void*)>>(
94-
data, std::move(deleter) ?: noop_deleter),
94+
data, deleter ? std::move(deleter) : noop_deleter),
9595
std::move(sizes),
9696
std::move(dim_order),
9797
std::move(strides)});

0 commit comments

Comments
 (0)