|
8 | 8 |
|
9 | 9 | #include <executorch/runtime/kernel/kernel_includes.h>
|
10 | 10 |
|
11 |
| -#include <cstring> |
12 |
| - |
13 | 11 | namespace torch {
|
14 | 12 | namespace executor {
|
15 | 13 | namespace native {
|
16 | 14 |
|
17 |
| -using exec_aten::Scalar; |
18 |
| -using ScalarType = exec_aten::ScalarType; |
19 |
| - |
20 |
| -namespace { |
21 |
| - |
22 |
| -/** |
23 |
| - * Checks sizes passed to `ones_out` (`size_int64_t`) matches those within |
24 |
| - * `out` tensor (`size_int32_t`). |
25 |
| - */ |
26 |
| -void size_check( |
27 |
| - exec_aten::ArrayRef<int64_t> size_int64_t, |
28 |
| - exec_aten::ArrayRef<int32_t> size_int32_t) { |
29 |
| - ET_CHECK(size_int64_t.size() == size_int32_t.size()); |
30 |
| - for (int i = 0; i < size_int64_t.size(); i++) { |
31 |
| - ET_CHECK(((int64_t)size_int32_t[i] == size_int64_t[i])); |
32 |
| - } |
33 |
| -}; |
34 |
| - |
35 |
| -/** |
36 |
| - * Fills the `out` tensor with value 1. |
37 |
| - */ |
38 |
| -template <class CTYPE> |
39 |
| -void ones_kernel(Tensor& out) { |
40 |
| - // Create pointer over `out` data with `CTYPE`. |
41 |
| - auto data_out = out.mutable_data_ptr<CTYPE>(); |
42 |
| - |
43 |
| - // Set each element of the tensor to the "1" value for the type. |
44 |
| - for (size_t i = 0; i < out.numel(); i++) { |
45 |
| - data_out[i] = static_cast<CTYPE>(1); |
46 |
| - } |
47 |
| -}; |
48 |
| - |
49 |
| -} // namespace |
50 |
| - |
51 |
| -/** |
52 |
| - * `ones_out` implementation. |
53 |
| - */ |
54 | 15 | Tensor& ones_out(RuntimeContext& ctx, IntArrayRef size, Tensor& out) {
|
55 | 16 | (void)ctx;
|
56 |
| - size_check(size, out.sizes()); |
57 |
| - |
58 |
| -#define ONES_OUT(ctype, dtype) \ |
59 |
| - case ScalarType::dtype: \ |
60 |
| - ones_kernel<ctype>(out); \ |
61 |
| - break; |
62 | 17 |
|
63 |
| - switch (out.scalar_type()) { |
64 |
| - ET_FORALL_REAL_TYPES_AND(Bool, ONES_OUT) |
65 |
| - default: |
66 |
| - ET_CHECK_MSG( |
67 |
| - false, |
68 |
| - "out tensor should be a real or bool dtype, but got %" PRId8, |
69 |
| - static_cast<int8_t>(out.scalar_type())); |
70 |
| - } |
71 |
| -#undef ONES_OUT |
| 18 | + // Resize for dynamic shape |
| 19 | + ET_KERNEL_CHECK( |
| 20 | + ctx, resize_tensor(out, size) == Error::Ok, InvalidArgument, out); |
| 21 | + |
| 22 | + ScalarType out_type = out.scalar_type(); |
| 23 | + ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE, [&] { |
| 24 | + auto out_data = out.mutable_data_ptr<CTYPE>(); |
| 25 | + for (size_t i = 0; i < out.numel(); i++) { |
| 26 | + out_data[i] = static_cast<CTYPE>(1); |
| 27 | + } |
| 28 | + }); |
72 | 29 |
|
73 | 30 | return out;
|
74 | 31 | }
|
|
0 commit comments