Skip to content

Commit 150b051

Browse files
SS-JIA authored and facebook-github-bot committed
Update and fix stack
Summary: Clean up implementation of `aten::stack_out`, and allow it to handle input tensor list with different dtypes. Use [ATen stack impl](https://fburl.com/code/9phz8y5w) as a reference. Reviewed By: manuelcandales Differential Revision: D47556485 fbshipit-source-id: 85784ac142672e31fa355cf5a64767c2b4c16c98
1 parent 334d4aa commit 150b051

File tree

5 files changed

+128
-112
lines changed

5 files changed

+128
-112
lines changed

kernels/portable/cpu/op_stack.cpp

Lines changed: 34 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <cstring>
44

5+
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
56
#include <executorch/runtime/kernel/kernel_includes.h>
67

78
namespace torch {
@@ -10,127 +11,48 @@ namespace native {
1011

1112
using Tensor = exec_aten::Tensor;
1213

13-
namespace {
14-
15-
// TODO(gasoonjia): Move this to a common spot so all implementation of
16-
// this operator can share it. (e.g., DSP-specific)
17-
/// Asserts that the parameters are valid.
18-
void check_stack_out_args(
19-
exec_aten::ArrayRef<Tensor> tensors,
20-
int64_t dim,
21-
Tensor& out) {
22-
// Stack expects non-empty tensor list
23-
ET_CHECK_MSG(tensors.size() > 0, "Stack expects non-empty tensor list");
24-
25-
// Ensure dim is in range. Use `out` as a proxy for all input tensors, since
26-
// they will all need to have the same number of dimensions besides the dim
27-
// one.
28-
ET_CHECK_MSG(
29-
dim >= 0 && dim < out.dim(),
30-
"dim %" PRId64 " out of range [0,%zd)",
31-
dim,
32-
out.dim());
33-
34-
for (size_t i = 0; i < tensors.size(); i++) {
35-
// All input dtypes must match the output dtype.
36-
ET_CHECK_MSG(
37-
tensors[i].scalar_type() == out.scalar_type(),
38-
"tensors[%zu] dtype %hhd != out dtype %hhd",
39-
i,
40-
tensors[i].scalar_type(),
41-
out.scalar_type());
42-
43-
// All input tensors need to be of the same size
44-
// Also, since we create a new axis in output for stacking, the output.dim()
45-
// should be one larger than input.dim()
46-
// https://pytorch.org/docs/stable/generated/torch.stack.html
47-
ET_CHECK_MSG(
48-
tensors[i].dim() == out.dim() - 1,
49-
"tensors[%zu].dim() %zd != out.dim() - 1 %zd",
50-
i,
51-
tensors[i].dim(),
52-
out.dim() - 1);
53-
54-
// The size of each input tensor should be the same. Here we use `out` as
55-
// proxy for comparsion. Also, the size of output tensor should follow these
56-
// rules:
57-
// - For any input tensor, its size(i) == output.size(i) if i < dim, and its
58-
// size(i) == output.size(i+1) if i >= dim
59-
// - For the cat dimension (output[dim]), its size should be the number of
60-
// input tensors
61-
for (size_t d = 0; d < tensors[i].dim(); d++) {
62-
if (d < dim) {
63-
ET_CHECK_MSG(
64-
tensors[i].size(d) == out.size(d),
65-
"tensors[%zu].size(%zu) %zd != out.size(%zu) %zd | dim = %" PRId64,
66-
i,
67-
d,
68-
tensors[i].size(d),
69-
d,
70-
out.size(d),
71-
dim);
72-
} else {
73-
ET_CHECK_MSG(
74-
tensors[i].size(d) == out.size(d + 1),
75-
"tensors[%zu].size(%zu) %zd != out.size(%zu) %zd | dim = %" PRId64,
76-
i,
77-
d,
78-
tensors[i].size(d),
79-
d + 1,
80-
out.size(d + 1),
81-
dim);
82-
}
83-
}
84-
}
85-
86-
// The size of the stack dimension of the output should be the number of
87-
// input tensors
88-
ET_CHECK_MSG(
89-
out.size(dim) == tensors.size(),
90-
"out.size(%" PRId64 ") %zd != number of input tensors %zu",
91-
dim,
92-
out.size(dim),
93-
tensors.size());
94-
}
95-
} // namespace
96-
97-
/// stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
9814
/// stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
///
/// Stacks the input tensors (which must all share one shape) along a new
/// axis inserted at `dim`, casting each element to the dtype of `out`.
/// Follows the ATen stack implementation as the reference behavior.
///
/// @param ctx Runtime context (unused).
/// @param tensors Non-empty list of same-shaped input tensors.
/// @param dim Insertion position; python-style negative values allowed.
/// @param out Output tensor; resized here to the stacked shape.
/// @return `out`, for chaining.
Tensor& stack_out(
    RuntimeContext& ctx,
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  (void)ctx;

  // Support python-style negative indexing. Wrap against the rank the output
  // will have after stacking (input rank + 1) rather than out.dim(): `out` is
  // only resized below, so its current rank may not yet be the final one.
  // Guarded on a non-empty list; an empty list is rejected by
  // check_stack_args regardless of `dim`.
  if (dim < 0 && tensors.size() > 0) {
    dim += tensors[0].dim() + 1;
  }

  check_stack_args(tensors, dim, out);

  // Resize `out` to the target shape: the common input shape with an extra
  // axis of length tensors.size() inserted at `dim`.
  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
  size_t expected_out_dim = 0;
  get_stack_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
  ET_CHECK(
      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok);

  // Copy layout: for each of `outer` leading slices, emit one `inner`-sized
  // contiguous chunk from every input tensor, in input order.
  const size_t outer = getLeadingDims(out, dim);
  const size_t inner = getTrailingDims(out, dim);
  const size_t ninputs = tensors.size();

  const auto out_type = out.scalar_type();
  ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "stack", CTYPE_OUT, [&] {
    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
    for (size_t i = 0; i < outer; ++i) {
      for (size_t j = 0; j < ninputs; ++j) {
        // Inputs may each have a different dtype (validated castable by
        // check_stack_args); every element is converted to the output dtype.
        const auto in_type = tensors[j].scalar_type();
        ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "stack", CTYPE_IN, [&] {
          const CTYPE_IN* const in_ptr =
              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;

          for (size_t k = 0; k < inner; ++k) {
            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
          }
          out_ptr += inner;
        });
      }
    }
  });

  return out;
}
13658

kernels/portable/cpu/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,9 @@ _ATEN_OPS = (
652652
),
653653
op_target(
654654
name = "op_stack",
655+
deps = [
656+
"//executorch/kernels/portable/cpu/util:copy_ops_util",
657+
],
655658
),
656659
op_target(
657660
name = "op_sub",
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
#include <cstring>
4+
5+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
6+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
7+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
8+
#include <executorch/runtime/platform/assert.h>
9+
10+
namespace torch {
11+
namespace executor {
12+
13+
using Tensor = exec_aten::Tensor;
14+
15+
/// Asserts that the arguments to stack.out are valid:
/// - `tensors` is non-empty,
/// - every input dtype can be cast to the output dtype,
/// - all inputs share a single shape, and
/// - `dim` (already normalized to be non-negative by the caller) is a valid
///   insertion position in the output, i.e. in [0, input_rank + 1).
///
/// @param tensors List of input tensors to be stacked.
/// @param dim Non-negative insertion position.
/// @param out Output tensor, used only for its dtype here.
void check_stack_args(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  // Ensure the input tensors list is non-empty
  ET_CHECK(tensors.size() > 0);

  // All input tensors need to be of the same size
  // https://pytorch.org/docs/stable/generated/torch.stack.html
  for (size_t i = 0; i < tensors.size(); i++) {
    // All input dtypes must be castable to the output dtype.
    ET_CHECK(canCast(tensors[i].scalar_type(), out.scalar_type()));

    ET_CHECK(tensors[i].dim() == tensors[0].dim());
    // Cast the (signed) rank once to avoid signed/unsigned comparison in the
    // loop condition; rank is non-negative so the cast is safe.
    for (size_t d = 0; d < static_cast<size_t>(tensors[i].dim()); d++) {
      ET_CHECK(tensors[i].size(d) == tensors[0].size(d));
    }
  }

  // The output tensor will have a dimension inserted, so dim should be
  // between 0 and ndim_of_inputs + 1
  ET_CHECK(dim >= 0 && dim < tensors[0].dim() + 1);
}
38+
39+
/// Computes the shape `out` must have for stack.out: the common input shape
/// with an axis of length tensors.size() inserted at position `dim`.
///
/// @param tensors Non-empty list of same-shaped inputs (pre-validated by
///     check_stack_args).
/// @param dim Non-negative insertion position in the output.
/// @param out_sizes Caller-owned buffer with room for at least
///     tensors[0].dim() + 1 entries; filled with the target sizes.
/// @param out_ndim Set to the output rank, tensors[0].dim() + 1.
void get_stack_out_target_size(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor::SizesType* out_sizes,
    size_t* out_ndim) {
  *out_ndim = tensors[0].dim() + 1;

  // `dim` was validated non-negative by check_stack_args, so this cast is
  // safe and keeps the comparisons below within one (unsigned) type.
  const size_t udim = static_cast<size_t>(dim);
  for (size_t d = 0; d < *out_ndim; ++d) {
    if (d < udim) {
      // Leading dims are unchanged from the inputs.
      out_sizes[d] = tensors[0].size(d);
    } else if (d == udim) {
      // The new stack axis has one slot per input tensor.
      out_sizes[d] = tensors.size();
    } else {
      // Trailing dims are the input dims shifted right by one.
      out_sizes[d] = tensors[0].size(d - 1);
    }
  }
}
56+
57+
} // namespace executor
58+
} // namespace torch
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

/// Asserts that the arguments to stack.out are valid: non-empty input list,
/// input dtypes castable to the output dtype, all inputs sharing one shape,
/// and a non-negative `dim` in [0, input_rank + 1).
void check_stack_args(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out);

/// Computes the target output shape for stack.out: the common input shape
/// with an axis of length tensors.size() inserted at `dim`. Writes the sizes
/// into `out_sizes` and the output rank into `*out_ndim`.
void get_stack_out_target_size(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor::SizesType* out_sizes,
    size_t* out_ndim);

} // namespace executor
} // namespace torch

kernels/portable/cpu/util/targets.bzl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ def define_common_targets():
3737
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
3838
)
3939

40+
# Shared helpers (arg checking, output-shape computation) for copy-style ops
# such as aten::stack, usable by both portable and optimized kernels.
runtime.cxx_library(
    name = "copy_ops_util",
    srcs = ["copy_ops_util.cpp"],
    exported_headers = [
        "copy_ops_util.h",
    ],
    compiler_flags = ["-Wno-missing-prototypes"],
    deps = [
        "//executorch/runtime/kernel:kernel_includes",
    ],
    visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
)
52+
4053
runtime.cxx_library(
4154
name = "transpose_util",
4255
exported_headers = [

0 commit comments

Comments
 (0)