Skip to content

Commit e7aebe9

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Pattern: unary_ufunc_real
Reviewed By: kirklandsign Differential Revision: D48380141 fbshipit-source-id: 4a47b9aa41e7010cf81fa90aa58a7ae4735f87ed
1 parent bf33da4 commit e7aebe9

File tree

5 files changed

+62
-23
lines changed

5 files changed

+62
-23
lines changed

kernels/portable/cpu/op_floor.cpp

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cmath>
10-
11-
#include <executorch/kernels/portable/cpu/util/functional_util.h>
9+
#include <executorch/kernels/portable/cpu/pattern/pattern.h>
1210
#include <executorch/runtime/kernel/kernel_includes.h>
13-
#include <executorch/runtime/platform/assert.h>
11+
#include <cmath>
1412

1513
namespace torch {
1614
namespace executor {
@@ -19,24 +17,7 @@ namespace native {
1917
using exec_aten::Tensor;
2018

2119
Tensor& floor_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
22-
(void)ctx;
23-
24-
// Resize for dynamic shape
25-
auto error = resize_tensor(out, in.sizes());
26-
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
27-
ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out);
28-
29-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "floor", CTYPE, [&] {
30-
apply_unary_map_fn(
31-
[](const CTYPE val_in) {
32-
return static_cast<CTYPE>(std::floor(val_in));
33-
},
34-
in.const_data_ptr<CTYPE>(),
35-
out.mutable_data_ptr<CTYPE>(),
36-
in.numel());
37-
});
38-
39-
return out;
20+
return internal::unary_ufunc_real(std::floor, ctx, in, out);
4021
}
4122

4223
} // namespace native

kernels/portable/cpu/pattern/pattern.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ namespace executor {
5151
namespace native {
5252
namespace internal {
5353

54+
/**
55+
* Implements an op pattern for ops that take a single input tensor of any
56+
* real dtye, no additional arguments, and outputs a tensor of the same size
57+
* and dtype. The function fn specifies the math operation which is applied to
58+
* the input tensor element-wise.
59+
*/
60+
Tensor& unary_ufunc_real(
61+
FunctionRef<double(double)> fn,
62+
RuntimeContext& ctx,
63+
const Tensor& in,
64+
Tensor& out);
65+
5466
/**
5567
* Implements an op pattern for ops that take a single input tensor of any
5668
* realb dtye (real and boolean), no additional arguments, and outputs a

kernels/portable/cpu/pattern/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def define_common_targets():
1313
"binary_ufunc_realb_realb_to_realb_logical.cpp",
1414
"unary_ufunc_realb_to_bool.cpp",
1515
"unary_ufunc_realb_to_float.cpp",
16+
"unary_ufunc_real.cpp",
1617
],
1718
exported_headers = [
1819
"pattern.h",
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/pattern/pattern.h>
10+
#include <executorch/kernels/portable/cpu/util/functional_util.h>
11+
#include <executorch/runtime/core/function_ref.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
namespace internal {
18+
19+
Tensor& unary_ufunc_real(
20+
FunctionRef<double(double)> fn,
21+
RuntimeContext& ctx,
22+
const Tensor& in,
23+
Tensor& out) {
24+
(void)ctx;
25+
26+
// Resize for dynamic shape
27+
auto error = resize_tensor(out, in.sizes());
28+
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
29+
ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out);
30+
31+
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
32+
apply_unary_map_fn(
33+
[fn](const CTYPE val_in) { return static_cast<CTYPE>(fn(val_in)); },
34+
in.const_data_ptr<CTYPE>(),
35+
out.mutable_data_ptr<CTYPE>(),
36+
in.numel());
37+
});
38+
39+
return out;
40+
}
41+
42+
} // namespace internal
43+
} // namespace native
44+
} // namespace executor
45+
} // namespace torch

kernels/portable/cpu/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ _ATEN_OPS = (
287287
op_target(
288288
name = "op_floor",
289289
deps = [
290-
"//executorch/kernels/portable/cpu/util:functional_util",
290+
"//executorch/kernels/portable/cpu/pattern:pattern",
291291
],
292292
),
293293
op_target(

0 commit comments

Comments
 (0)