Skip to content

Commit beabe34

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Add op: repeat_interleave.Tensor_out
Differential Revision: D67025538
1 parent f6a87ac commit beabe34

File tree

6 files changed

+167
-0
lines changed

6 files changed

+167
-0
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@
313313

314314
- op: repeat.out
315315

316+
- op: repeat_interleave.Tensor_out
317+
316318
- op: reflection_pad1d.out
317319

318320
- op: reflection_pad2d.out
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/kernel/kernel_includes.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace native {
14+
namespace {
15+
16+
bool check_repeat_interleave_args(
17+
const Tensor& repeats,
18+
int64_t output_size_value,
19+
int64_t repeats_sum,
20+
Tensor& out) {
21+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
22+
repeats.scalar_type() == ScalarType::Int ||
23+
repeats.scalar_type() == ScalarType::Long,
24+
"repeats must be int or long");
25+
ET_LOG_MSG_AND_RETURN_IF_FALSE(repeats.dim() == 1, "repeats must be 1D");
26+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
27+
output_size_value == repeats_sum,
28+
"output_size, if provided, must be equal to repeats.sum()");
29+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(repeats, out));
30+
31+
if (repeats.scalar_type() == ScalarType::Long) {
32+
const int64_t* const repeats_data = repeats.const_data_ptr<int64_t>();
33+
for (size_t i = 0; i < repeats.numel(); ++i) {
34+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
35+
repeats_data[i] >= 0, "repeats cannot be negative");
36+
}
37+
} else {
38+
const int32_t* const repeats_data = repeats.const_data_ptr<int32_t>();
39+
for (size_t i = 0; i < repeats.numel(); ++i) {
40+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
41+
repeats_data[i] >= 0, "repeats cannot be negative");
42+
}
43+
}
44+
45+
return true;
46+
}
47+
48+
} // namespace
49+
50+
using Tensor = exec_aten::Tensor;
51+
52+
Tensor& repeat_interleave_Tensor_out(
53+
KernelRuntimeContext& ctx,
54+
const Tensor& repeats,
55+
exec_aten::optional<int64_t> output_size,
56+
Tensor& out) {
57+
(void)ctx;
58+
59+
int64_t repeats_sum = 0;
60+
61+
constexpr auto name = "repeat_interleave.Tensor_out";
62+
63+
ET_SWITCH_TWO_TYPES(Int, Long, repeats.scalar_type(), ctx, name, CTYPE, [&] {
64+
const CTYPE* repeats_data = repeats.const_data_ptr<CTYPE>();
65+
for (size_t ix = 0; ix < repeats.numel(); ++ix) {
66+
repeats_sum += static_cast<int64_t>(repeats_data[ix]);
67+
}
68+
});
69+
70+
int64_t output_size_value =
71+
output_size.has_value() ? output_size.value() : repeats_sum;
72+
73+
ET_KERNEL_CHECK(
74+
ctx,
75+
check_repeat_interleave_args(
76+
repeats, output_size_value, repeats_sum, out),
77+
InvalidArgument,
78+
out);
79+
80+
ET_KERNEL_CHECK(
81+
ctx, tensors_have_same_dim_order(repeats, out), InvalidArgument, out);
82+
83+
ET_KERNEL_CHECK(
84+
ctx, tensor_is_default_dim_order(repeats), InvalidArgument, out);
85+
86+
ET_KERNEL_CHECK_MSG(
87+
ctx,
88+
resize_tensor(
89+
out, {static_cast<exec_aten::SizesType>(output_size_value)}) ==
90+
Error::Ok,
91+
InvalidArgument,
92+
out,
93+
"Failed to resize output tensor.");
94+
95+
ET_SWITCH_TWO_TYPES(Int, Long, repeats.scalar_type(), ctx, name, CTYPE, [&] {
96+
const CTYPE* repeats_data = repeats.const_data_ptr<CTYPE>();
97+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
98+
size_t out_ix = 0;
99+
for (size_t ix = 0; ix < repeats.numel(); ix++) {
100+
for (CTYPE i = 0; i < repeats_data[ix]; i++, out_ix++) {
101+
out_data[out_ix] = static_cast<CTYPE>(ix);
102+
}
103+
}
104+
});
105+
106+
return out;
107+
}
108+
109+
} // namespace native
110+
} // namespace executor
111+
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,11 @@
712712
- arg_meta: null
713713
kernel_name: torch::executor::repeat_out
714714

715+
- op: repeat_interleave.Tensor_out
716+
kernels:
717+
- arg_meta: null
718+
kernel_name: torch::executor::repeat_interleave_Tensor_out
719+
715720
- op: reflection_pad1d.out
716721
kernels:
717722
- arg_meta: null
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/kernels/test/supported_features.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16+
17+
using namespace ::testing;
18+
using exec_aten::optional;
19+
using exec_aten::ScalarType;
20+
using exec_aten::Tensor;
21+
using torch::executor::testing::TensorFactory;
22+
23+
class OpRepeatInterleaveTensorOutTest : public OperatorTest {
24+
protected:
25+
Tensor& op_repeat_out(
26+
const Tensor& repeats,
27+
optional<int64_t> output_size,
28+
Tensor& out) {
29+
return torch::executor::aten::repeat_interleave_outf(
30+
context_, repeats, output_size, out);
31+
}
32+
};
33+
34+
TEST_F(OpRepeatInterleaveTensorOutTest, SmokeTest) {
35+
TensorFactory<ScalarType::Int> tf;
36+
37+
Tensor repeats = tf.make({3}, {2, 3, 1});
38+
39+
std::vector<int64_t> repeats_vec = {3, 4, 5, 6};
40+
Tensor out = tf.zeros({6});
41+
Tensor expected = tf.make({6}, {0, 0, 1, 1, 1, 2});
42+
Tensor ret = op_repeat_out(repeats, 6, out);
43+
EXPECT_TENSOR_EQ(ret, out);
44+
EXPECT_TENSOR_EQ(ret, expected);
45+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def define_common_targets():
283283
_common_op_test("op_relu_test", ["aten", "portable"])
284284
_common_op_test("op_remainder_test", ["aten", "portable"])
285285
_common_op_test("op_repeat_test", ["aten", "portable"])
286+
_common_op_test("op_repeat_interleave_test", ["aten", "portable"])
286287
_common_op_test("op_reflection_pad1d_test", ["aten", "portable"])
287288
_common_op_test("op_reflection_pad2d_test", ["aten", "portable"])
288289
_common_op_test("op_reflection_pad3d_test", ["aten", "portable"])

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,9 @@ ATEN_OPS = (
10011001
"//executorch/kernels/portable/cpu/util:repeat_util",
10021002
],
10031003
),
1004+
op_target(
1005+
name = "op_repeat_interleave",
1006+
),
10041007
op_target(
10051008
name = "op_replication_pad1d",
10061009
deps = [

0 commit comments

Comments
 (0)