Skip to content

Commit d3df91a

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Fix & cleanup op ones (#699)
Summary: Pull Request resolved: #699 Resize out tensor ghstack-source-id: 203341575 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D49735857 fbshipit-source-id: 75c30806c541f097641a73cc502443834d071789
1 parent 75df54d commit d3df91a

File tree

1 file changed

+11
-54
lines changed

1 file changed

+11
-54
lines changed

kernels/portable/cpu/op_ones.cpp

Lines changed: 11 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,67 +8,24 @@
88

99
#include <executorch/runtime/kernel/kernel_includes.h>
1010

11-
#include <cstring>
12-
1311
namespace torch {
1412
namespace executor {
1513
namespace native {
1614

17-
using exec_aten::Scalar;
18-
using ScalarType = exec_aten::ScalarType;
19-
20-
namespace {
21-
22-
/**
23-
* Checks sizes passed to `ones_out` (`size_int64_t`) matches those within
24-
* `out` tensor (`size_int32_t`).
25-
*/
26-
void size_check(
27-
exec_aten::ArrayRef<int64_t> size_int64_t,
28-
exec_aten::ArrayRef<int32_t> size_int32_t) {
29-
ET_CHECK(size_int64_t.size() == size_int32_t.size());
30-
for (int i = 0; i < size_int64_t.size(); i++) {
31-
ET_CHECK(((int64_t)size_int32_t[i] == size_int64_t[i]));
32-
}
33-
};
34-
35-
/**
36-
* Fills the `out` tensor with value 1.
37-
*/
38-
template <class CTYPE>
39-
void ones_kernel(Tensor& out) {
40-
// Create pointer over `out` data with `CTYPE`.
41-
auto data_out = out.mutable_data_ptr<CTYPE>();
42-
43-
// Set each element of the tensor to the "1" value for the type.
44-
for (size_t i = 0; i < out.numel(); i++) {
45-
data_out[i] = static_cast<CTYPE>(1);
46-
}
47-
};
48-
49-
} // namespace
50-
51-
/**
52-
* `ones_out` implementation.
53-
*/
5415
Tensor& ones_out(RuntimeContext& ctx, IntArrayRef size, Tensor& out) {
5516
(void)ctx;
56-
size_check(size, out.sizes());
57-
58-
#define ONES_OUT(ctype, dtype) \
59-
case ScalarType::dtype: \
60-
ones_kernel<ctype>(out); \
61-
break;
6217

63-
switch (out.scalar_type()) {
64-
ET_FORALL_REAL_TYPES_AND(Bool, ONES_OUT)
65-
default:
66-
ET_CHECK_MSG(
67-
false,
68-
"out tensor should be a real or bool dtype, but got %" PRId8,
69-
static_cast<int8_t>(out.scalar_type()));
70-
}
71-
#undef ONES_OUT
18+
// Resize for dynamic shape
19+
ET_KERNEL_CHECK(
20+
ctx, resize_tensor(out, size) == Error::Ok, InvalidArgument, out);
21+
22+
ScalarType out_type = out.scalar_type();
23+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE, [&] {
24+
auto out_data = out.mutable_data_ptr<CTYPE>();
25+
for (size_t i = 0; i < out.numel(); i++) {
26+
out_data[i] = static_cast<CTYPE>(1);
27+
}
28+
});
7229

7330
return out;
7431
}

0 commit comments

Comments
 (0)