
Commit a5cb1e2

SS-JIA authored and facebook-github-bot committed
Add squeeze_copy.dims (#655)
Summary: Pull Request resolved: #655

Add `squeeze_copy.dims`, which `squeeze_copy.dim` will decompose to.

Reviewed By: manuelcandales

Differential Revision: D49988507

fbshipit-source-id: b273db0a6cb9286e583c996a6aec58c0eeec4ce1
1 parent 6871938 commit a5cb1e2

File tree

8 files changed: +284, -92 lines

exir/dialects/edge/op/sample_input.py

Lines changed: 9 additions & 0 deletions
@@ -1109,6 +1109,15 @@
             Return(ArgType.Tensor, size=[2]),
         ],
     },
+    "squeeze_copy.dims": {  # (Tensor self, int[] dims) -> Tensor
+        "args": [
+            InArg(ArgType.Tensor, size=[1, 2, 1, 5]),
+            InArg(ArgType.Param, value=[0, 2]),
+        ],
+        "returns": [
+            Return(ArgType.Tensor, size=[2, 5]),
+        ],
+    },
     "stack.default": {  # (Tensor[] tensors, int dim=0) -> Tensor
         "args": [
             InArg(
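The sample input above pins down the shape contract for the new overload: a dimension listed in `dims` is removed only when its size is 1, every other dimension is kept in order, and `squeeze_copy.dim` is just the one-element-list case it will decompose to. A minimal standalone sketch of that rule over plain size vectors (`squeeze_dims_sizes` is an illustrative name, not an ExecuTorch API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: computes the output sizes of squeeze over a list of
// dims. A dim listed in `dims` is dropped only if its size is exactly 1;
// all other dimensions are preserved in order.
std::vector<int64_t> squeeze_dims_sizes(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& dims) {
  std::vector<int64_t> out;
  for (int64_t d = 0; d < static_cast<int64_t>(sizes.size()); ++d) {
    bool listed = false;
    for (int64_t dim : dims) {
      // Normalize negative dims the same way the kernel does.
      if (dim < 0) {
        dim += static_cast<int64_t>(sizes.size());
      }
      if (dim == d) {
        listed = true;
        break;
      }
    }
    if (!listed || sizes[d] != 1) {
      out.push_back(sizes[d]);
    }
  }
  return out;
}

int main() {
  // Mirrors the registered sample input: [1, 2, 1, 5] with dims [0, 2].
  assert((squeeze_dims_sizes({1, 2, 1, 5}, {0, 2}) ==
          std::vector<int64_t>{2, 5}));
  // squeeze_copy.dim decomposing to squeeze_copy.dims: a one-element list.
  assert((squeeze_dims_sizes({1, 2, 1, 5}, {0}) ==
          std::vector<int64_t>{2, 1, 5}));
  return 0;
}
```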

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -290,6 +290,8 @@
 
 - op: squeeze_copy.dim_out
 
+- op: squeeze_copy.dims_out
+
 - op: squeeze_copy.out
 
 - op: stack.out

kernels/portable/cpu/op_squeeze_copy.cpp

Lines changed: 45 additions & 92 deletions
@@ -10,6 +10,7 @@
 #include <cstdint>
 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -18,113 +19,65 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 
-namespace {
-
-void check_squeeze_copy_dim_out(
-    const Tensor input,
-    int64_t dim,
-    const Tensor out) {
-  if (input.dim() != 0 && input.size(dim) == 1) {
-    ET_CHECK(input.dim() == out.dim() + 1);
-
-    for (size_t d = 0; d < out.dim(); ++d) {
-      if (d < dim) {
-        // d < dim
-        ET_CHECK_MSG(
-            input.size(d) == out.size(d),
-            "input.size(%zu) %zd != out.size(%zu) %zd | dim = %" PRId64,
-            d,
-            input.size(d),
-            d,
-            out.size(d),
-            dim);
-      } else {
-        // d >= dim
-        ET_CHECK_MSG(
-            input.size(d + 1) == out.size(d),
-            "input.size(%zu) %zd != out.size(%zu) %zd | dim = %" PRId64,
-            d + 1,
-            input.size(d),
-            d,
-            out.size(d),
-            dim);
-      }
-    }
-  } else {
-    ET_CHECK(input.dim() == out.dim());
-
-    for (size_t d = 0; d < out.dim(); ++d) {
-      ET_CHECK_MSG(
-          input.size(d) == out.size(d),
-          "input.size(%zu) %zd != out.size(%zu) %zd | dim = %" PRId64,
-          d,
-          input.size(d),
-          d,
-          out.size(d),
-          dim);
-    }
-  }
-}
-} // namespace
-
-//
-// squeeze_copy.dim_out(Tensor self, int dim, Tensor(a!) out) -> Tensor(a!)
-//
 Tensor& squeeze_copy_dim_out(
     RuntimeContext& ctx,
-    const Tensor& self,
+    const Tensor& in,
     int64_t dim,
     Tensor& out) {
   (void)ctx;
-  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
-
-  // The input and out shall share same dtype
-  ET_CHECK_SAME_DTYPE2(self, out);
-
-  // A valid dim must be in [-self.dim(), self.dim())
-  if (self.dim() == 0 && dim == -1) {
-    dim = 0;
-  }
-  ET_CHECK_MSG(
-      (self.dim() == 0 && dim == 0) || (dim >= -self.dim() && dim < self.dim()),
-      "dim %" PRId64 " out of range [-%zd,%zd)",
-      dim,
-      self.dim(),
-      self.dim());
 
+  // TODO(ssjia): use nonzero_dim() instead
   if (dim < 0) {
-    dim += self.dim();
+    dim += in.dim();
   }
 
-  size_t expected_out_dim = (self.dim() == 0 || self.size(dim) != 1)
-      ? self.dim()
-      : std::max<ssize_t>(self.dim() - 1, 0);
+  ET_KERNEL_CHECK(
+      ctx, check_squeeze_copy_dim_args(in, dim, out), InvalidArgument, out);
+
+  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
+  size_t expected_out_dim = 0;
+  get_squeeze_copy_dim_out_target_size(
+      in, dim, expected_out_size, &expected_out_dim);
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
+      InvalidArgument,
+      out);
 
-  if (dim == self.dim() || self.size(dim) != 1) {
-    for (size_t i = 0; i < expected_out_dim; ++i) {
-      expected_output_size[i] = self.size(i);
-    }
-  } else {
-    // 0 <= dim < self.dim() AND self.size(dim) == 1
-    for (size_t i = 0; i < expected_out_dim; ++i) {
-      if (i < dim) {
-        expected_output_size[i] = self.size(i);
-      } else {
-        // Squeeze the given dimension 'dim'
-        expected_output_size[i] = self.size(i + 1);
-      }
-    }
+  if (in.nbytes() > 0) {
+    // Note that this check is important. It's valid for a tensor with numel 0
+    // to have a null data pointer, but in some environments it's invalid to
+    // pass a null pointer to memcpy() even when the size is zero.
+    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
   }
-  ET_CHECK_MSG(
-      Error::Ok == resize_tensor(out, {expected_output_size, expected_out_dim}),
-      "Failed to resize output tensor.");
-  check_squeeze_copy_dim_out(self, dim, out);
+  return out;
+}
+
+Tensor& squeeze_copy_dims_out(
+    RuntimeContext& ctx,
+    const Tensor& in,
+    exec_aten::ArrayRef<int64_t> dims,
+    Tensor& out) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx, check_squeeze_copy_dims_args(in, dims, out), InvalidArgument, out);
+
+  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
+  size_t expected_out_dim = 0;
+  get_squeeze_copy_dims_out_target_size(
+      in, dims, expected_out_size, &expected_out_dim);
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
+      InvalidArgument,
+      out);
 
-  if (self.nbytes() > 0) {
+  if (in.nbytes() > 0) {
     // Note that this check is important. It's valid for a tensor with numel 0
     // to have a null data pointer, but in some environments it's invalid to
     // pass a null pointer to memcpy() even when the size is zero.
-    memcpy(out.mutable_data_ptr(), self.const_data_ptr(), self.nbytes());
+    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
   }
   return out;
 }
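Both out-variant kernels now share the same structure: validate the arguments, compute the expected output size, resize `out`, then do a guarded `memcpy`. The guard on `nbytes() > 0` is load-bearing: as the diff's own comment notes, an empty tensor may legitimately hold a null data pointer, and passing a null pointer to `memcpy` is undefined behavior even with a zero byte count. A minimal sketch of the guard in isolation (`copy_bytes` is an illustrative name, not an ExecuTorch helper):

```cpp
#include <cstddef>
#include <cstring>

// Skip the copy entirely when there is nothing to move, so that a null
// src/dst pointer from an empty tensor never reaches memcpy (which would
// be undefined behavior even for a zero-length copy).
void copy_bytes(void* dst, const void* src, size_t nbytes) {
  if (nbytes > 0) {
    std::memcpy(dst, src, nbytes);
  }
}
```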

kernels/portable/cpu/targets.bzl

Lines changed: 3 additions & 0 deletions
@@ -705,6 +705,9 @@ _ATEN_OPS = (
     ),
     op_target(
         name = "op_squeeze_copy",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:copy_ops_util",
+        ],
     ),
     op_target(
         name = "op_stack",

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 105 additions & 0 deletions
@@ -204,6 +204,111 @@ void get_split_with_sizes_copy_out_target_size(
   out_sizes[dim] = split_size;
 }
 
+bool check_squeeze_copy_dim_args(
+    const Tensor in,
+    int64_t dim,
+    const Tensor out) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
+
+  return true;
+}
+
+void get_squeeze_copy_dim_out_target_size(
+    const Tensor in,
+    int64_t dim,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim) {
+  // For 0 dim tensors, the output should also be 0 dim.
+  if (in.dim() == 0) {
+    *out_ndim = 0;
+    return;
+  }
+
+  // Specified dim is only removed if the size at the given dim is 1.
+  if (in.size(dim) == 1) {
+    *out_ndim = in.dim() - 1;
+  } else {
+    *out_ndim = in.dim();
+  }
+
+  size_t out_d = 0;
+  for (size_t in_d = 0; in_d < in.dim(); ++in_d) {
+    if (in_d != dim || in.size(in_d) > 1) {
+      out_sizes[out_d] = in.size(in_d);
+      ++out_d;
+    }
+  }
+}
+
+bool check_squeeze_copy_dims_args(
+    const Tensor in,
+    const exec_aten::ArrayRef<int64_t> dims,
+    const Tensor out) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+
+  const int64_t dim_adjust = in.dim() == 0 ? 1 : in.dim();
+  for (size_t i = 0; i < dims.size(); ++i) {
+    // TODO(ssjia): use nonzero_dim() instead
+    const int64_t dim = dims[i] < 0 ? dims[i] + dim_adjust : dims[i];
+    ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
+
+    // Check that a dim does not appear twice in dims
+    for (size_t j = 0; j < dims.size(); ++j) {
+      if (i != j) {
+        const int64_t dim_temp = dims[j] < 0 ? dims[j] + dim_adjust : dims[j];
+        ET_LOG_MSG_AND_RETURN_IF_FALSE(
+            dim != dim_temp,
+            "dim %" PRId64 " appears multiple times in dims!",
+            dim);
+      }
+    }
+  }
+
+  return true;
+}
+
+void get_squeeze_copy_dims_out_target_size(
+    const Tensor in,
+    const exec_aten::ArrayRef<int64_t> dims,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim) {
+  // For 0 dim tensors, the output should also be 0 dim.
+  if (in.dim() == 0) {
+    *out_ndim = 0;
+    return;
+  }
+
+  int64_t dim_adjust = in.dim() == 0 ? 1 : in.dim();
+  // A dim is only removed if the size at the given dim is 1.
+  Tensor::SizesType dims_to_remove = 0;
+  for (size_t i = 0; i < dims.size(); ++i) {
+    // TODO(ssjia): use nonzero_dim() instead
+    int64_t dim = dims[i] < 0 ? dims[i] + dim_adjust : dims[i];
+    if (in.size(dim) == 1) {
+      ++dims_to_remove;
+    }
+  }
+  *out_ndim = in.dim() - dims_to_remove;
+
+  size_t out_d = 0;
+  for (size_t in_d = 0; in_d < in.dim(); ++in_d) {
+    bool in_d_in_dims = false;
+    for (size_t i = 0; i < dims.size(); ++i) {
+      // TODO(ssjia): use nonzero_dim() instead
+      int64_t dim = dims[i] < 0 ? dims[i] + dim_adjust : dims[i];
+      if (in_d == dim) {
+        in_d_in_dims = true;
+        break;
+      }
+    }
+    if (!in_d_in_dims || in.size(in_d) > 1) {
+      out_sizes[out_d] = in.size(in_d);
+      ++out_d;
+    }
+  }
+}
+
 bool check_stack_args(
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,
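`check_squeeze_copy_dims_args` normalizes negative entries with `dim_adjust` (a 0-dim tensor is treated as having one squeezable dim, so `-1` wraps to `0`) and rejects duplicates only after normalization, so `dims = [1, -3]` on a 4-dim input fails because `-3` also names dim 1. A self-contained sketch of that logic, assuming a plain rank instead of a Tensor and approximating `tensor_has_dim` with a range check (`dims_are_valid` is an illustrative name):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative only: mirrors the normalization and duplicate check in
// check_squeeze_copy_dims_args. A 0-dim tensor is treated as having one
// dimension for wrapping purposes, so dim == -1 normalizes to 0.
bool dims_are_valid(int64_t ndim, const std::vector<int64_t>& dims) {
  const int64_t dim_adjust = ndim == 0 ? 1 : ndim;
  for (size_t i = 0; i < dims.size(); ++i) {
    const int64_t dim_i = dims[i] < 0 ? dims[i] + dim_adjust : dims[i];
    if (dim_i < 0 || dim_i >= dim_adjust) {
      return false; // out of range (stand-in for tensor_has_dim)
    }
    for (size_t j = i + 1; j < dims.size(); ++j) {
      const int64_t dim_j = dims[j] < 0 ? dims[j] + dim_adjust : dims[j];
      if (dim_i == dim_j) {
        return false; // same dim listed twice after normalization
      }
    }
  }
  return true;
}

int main() {
  std::printf("%d\n", dims_are_valid(4, {0, 2}));  // 1: distinct dims
  std::printf("%d\n", dims_are_valid(4, {1, -3})); // 0: -3 normalizes to 1
  std::printf("%d\n", dims_are_valid(0, {-1}));    // 1: 0-dim wraps to 0
  return 0;
}
```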

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 22 additions & 0 deletions
@@ -56,6 +56,28 @@ void get_split_with_sizes_copy_out_target_size(
     Tensor::SizesType* out_sizes,
     size_t* out_ndim);
 
+bool check_squeeze_copy_dim_args(
+    const Tensor in,
+    int64_t dim,
+    const Tensor out);
+
+void get_squeeze_copy_dim_out_target_size(
+    const Tensor in,
+    int64_t dim,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim);
+
+bool check_squeeze_copy_dims_args(
+    const Tensor in,
+    const exec_aten::ArrayRef<int64_t> dims,
+    const Tensor out);
+
+void get_squeeze_copy_dims_out_target_size(
+    const Tensor in,
+    const exec_aten::ArrayRef<int64_t> dims,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim);
+
 bool check_stack_args(
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,
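These declarations follow the utility file's allocation-free convention: the caller supplies a stack buffer of capacity `kTensorDimensionLimit` and receives the actual rank back through `out_ndim`. A minimal sketch of that calling pattern with plain types (`kMaxRank` and `example_target_size` are illustrative stand-ins, not ExecuTorch names):

```cpp
#include <cstddef>
#include <cstdio>

constexpr int kMaxRank = 16; // stand-in for kTensorDimensionLimit

// The helper writes the computed sizes into the caller-owned buffer and
// reports the actual rank through out_ndim; no heap allocation needed.
void example_target_size(int* out_sizes, size_t* out_ndim) {
  out_sizes[0] = 2;
  out_sizes[1] = 5;
  *out_ndim = 2;
}

int main() {
  int sizes[kMaxRank];
  size_t ndim = 0;
  example_target_size(sizes, &ndim);
  for (size_t i = 0; i < ndim; ++i) {
    std::printf("%d ", sizes[i]); // prints: 2 5
  }
  std::printf("\n");
  return 0;
}
```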

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -647,6 +647,11 @@
     - arg_meta: null
       kernel_name: torch::executor::squeeze_copy_dim_out
 
+- op: squeeze_copy.dims_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::squeeze_copy_dims_out
+
 - op: stack.out
   kernels:
     - arg_meta: null
