Skip to content

Commit 8aa5db4

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Add ceil op
Reviewed By: SS-JIA, kirklandsign Differential Revision: D48380140 fbshipit-source-id: 79da1b45f44dee2fd18f328d4587f469c4f4f898
1 parent e7aebe9 commit 8aa5db4

File tree

6 files changed

+77
-0
lines changed

6 files changed

+77
-0
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@
8181

8282
- op: cat.out
8383

84+
- op: ceil.out
85+
8486
- op: clamp_min.out
8587

8688
- op: clamp.out

kernels/portable/cpu/op_ceil.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/pattern/pattern.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
#include <cmath>
12+
13+
namespace torch {
14+
namespace executor {
15+
namespace native {
16+
17+
using exec_aten::Tensor;
18+
19+
Tensor& ceil_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
20+
return internal::unary_ufunc_real(std::ceil, ctx, in, out);
21+
}
22+
23+
} // namespace native
24+
} // namespace executor
25+
} // namespace torch

kernels/portable/cpu/targets.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ _ATEN_OPS = (
178178
"//executorch/kernels/portable/cpu/util:copy_ops_util",
179179
],
180180
),
181+
op_target(
182+
name = "op_ceil",
183+
deps = [
184+
"//executorch/kernels/portable/cpu/pattern:pattern",
185+
],
186+
),
181187
op_target(
182188
name = "op_clamp",
183189
deps = [

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,11 @@
177177
- arg_meta: null
178178
kernel_name: torch::executor::cat_out
179179

180+
- op: ceil.out
181+
kernels:
182+
- arg_meta: null
183+
kernel_name: torch::executor::ceil_out
184+
180185
- op: clamp.out
181186
cpp_no_default_args: ['min']
182187
kernels:

kernels/test/op_ceil_test.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14+
15+
#include <gtest/gtest.h>
16+
17+
using namespace ::testing;
18+
using exec_aten::ScalarType;
19+
using exec_aten::Tensor;
20+
using torch::executor::testing::TensorFactory;
21+
22+
Tensor& op_ceil_out(const Tensor& self, Tensor& out) {
23+
exec_aten::RuntimeContext context{};
24+
return torch::executor::aten::ceil_outf(context, self, out);
25+
}
26+
27+
TEST(OpCeilTest, SanityCheck) {
28+
TensorFactory<ScalarType::Float> tf;
29+
30+
Tensor in = tf.make({1, 7}, {-3.0, -2.99, -1.01, 0.0, 1.01, 2.99, 3.0});
31+
Tensor out = tf.zeros({1, 7});
32+
Tensor expected = tf.make({1, 7}, {-3.0, -2.0, -1.0, 0.0, 2.0, 3.0, 3.0});
33+
34+
Tensor ret = op_ceil_out(in, out);
35+
36+
EXPECT_TENSOR_EQ(out, ret);
37+
EXPECT_TENSOR_EQ(out, expected);
38+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def define_common_targets():
179179
_common_op_test("op_bitwise_xor_test", ["aten", "portable"])
180180
_common_op_test("op_bmm_test", ["aten", "portable", "optimized"])
181181
_common_op_test("op_cat_test", ["aten", "portable"])
182+
_common_op_test("op_ceil_test", ["aten", "portable"])
182183
_common_op_test("op_clamp_test", ["aten", "portable"])
183184
_common_op_test("op_clone_test", ["aten", "portable"])
184185
_common_op_test("op_constant_pad_nd_test", ["aten", "portable"])

0 commit comments

Comments
 (0)