Skip to content

Commit 16b633b

Browse files
Add ops: max.unary_out & min.unary_out
Differential Revision: D64986580 Pull Request resolved: #6500
1 parent 800fc27 commit 16b633b

File tree

7 files changed

+224
-8
lines changed

7 files changed

+224
-8
lines changed

kernels/aten/functions.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,12 +249,16 @@
249249

250250
- op: max.unary_out
251251

252+
- op: max.unary_out
253+
252254
- op: maximum.out
253255

254256
- op: mean.out
255257

256258
- op: min.dim_min
257259

260+
- op: min.unary_out
261+
258262
- op: minimum.out
259263

260264
- op: mm.out

kernels/portable/cpu/op_max.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,22 @@
99
#include <cmath>
1010
#include <tuple>
1111

12-
#include <executorch/kernels/portable/cpu/util/index_util.h>
1312
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
1413
#include <executorch/runtime/kernel/kernel_includes.h>
1514
#include <executorch/runtime/platform/assert.h>
1615

1716
namespace torch {
1817
namespace executor {
1918
namespace native {
19+
namespace {

// Identity element for a running maximum: -infinity when the type has
// one, otherwise the smallest representable value of the type.
template <typename CTYPE>
constexpr CTYPE lower_bound() {
  if constexpr (std::numeric_limits<CTYPE>::has_infinity) {
    return -std::numeric_limits<CTYPE>::infinity();
  } else {
    return std::numeric_limits<CTYPE>::lowest();
  }
}

} // namespace
2028

2129
using ScalarType = exec_aten::ScalarType;
2230
using SizesType = exec_aten::SizesType;
@@ -94,6 +102,44 @@ std::tuple<Tensor&, Tensor&> max_out(
94102
return {max, max_indices};
95103
}
96104

105+
Tensor&
max_unary_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
  // Reduces `in` to a 0-D tensor holding its maximum element, cast to
  // the dtype of `out`. A NaN element makes the result NaN (the scan
  // stops at the first NaN). For an empty input the result is the
  // identity element for max: -inf for floating types, lowest() for
  // integral types (matching this op's unit tests).
  //
  // NOTE: the original body had a redundant `(void)ctx;` — ctx is used
  // by every ET_KERNEL_CHECK below, so the void-cast was dead code.
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  ScalarType out_type = out.scalar_type();

  // The input dtype must be castable to the output dtype.
  ET_KERNEL_CHECK(ctx, canCast(in_type, out_type), InvalidArgument, out);

  constexpr auto name = "max.unary_out";

  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
      const auto data_in = in.const_data_ptr<CTYPE_IN>();
      auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
      // Seed with the identity element; this is also the final result
      // when the input is empty.
      data_out[0] = lower_bound<CTYPE_OUT>();
      // Hoist numel() out of the loop and use a matching index type
      // (the original `auto i = 0` was a signed int compared against
      // the wider numel type every iteration).
      const auto numel = in.numel();
      for (decltype(in.numel()) i = 0; i < numel; ++i) {
        CTYPE_OUT val = static_cast<CTYPE_OUT>(data_in[i]);
        if (std::isnan(val)) {
          data_out[0] = val;
          break; // NaN dominates; no later element can change the result.
        }
        if (val > data_out[0]) {
          data_out[0] = val;
        }
      }
    });
  });

  return out;
}
142+
97143
} // namespace native
98144
} // namespace executor
99145
} // namespace torch

kernels/portable/cpu/op_min.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,22 @@
99
#include <cmath>
1010
#include <tuple>
1111

12-
#include <executorch/kernels/portable/cpu/util/index_util.h>
1312
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
1413
#include <executorch/runtime/kernel/kernel_includes.h>
1514
#include <executorch/runtime/platform/assert.h>
1615

1716
namespace torch {
1817
namespace executor {
1918
namespace native {
19+
namespace {

// Identity element for a running minimum: +infinity when the type has
// one, otherwise the largest representable value of the type.
template <typename CTYPE>
constexpr CTYPE upper_bound() {
  if constexpr (std::numeric_limits<CTYPE>::has_infinity) {
    return std::numeric_limits<CTYPE>::infinity();
  } else {
    return std::numeric_limits<CTYPE>::max();
  }
}

} // namespace
2028

2129
using ScalarType = exec_aten::ScalarType;
2230
using SizesType = exec_aten::SizesType;
@@ -94,6 +102,44 @@ std::tuple<Tensor&, Tensor&> min_out(
94102
return {min, min_indices};
95103
}
96104

105+
Tensor&
min_unary_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
  // Reduces `in` to a 0-D tensor holding its minimum element, cast to
  // the dtype of `out`. A NaN element makes the result NaN (the scan
  // stops at the first NaN). For an empty input the result is the
  // identity element for min: +inf for floating types, max() for
  // integral types (matching this op's unit tests).
  //
  // NOTE: the original body had a redundant `(void)ctx;` — ctx is used
  // by every ET_KERNEL_CHECK below, so the void-cast was dead code.
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  ScalarType out_type = out.scalar_type();

  // The input dtype must be castable to the output dtype.
  ET_KERNEL_CHECK(ctx, canCast(in_type, out_type), InvalidArgument, out);

  constexpr auto name = "min.unary_out";

  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
      const auto data_in = in.const_data_ptr<CTYPE_IN>();
      auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
      // Seed with the identity element; this is also the final result
      // when the input is empty.
      data_out[0] = upper_bound<CTYPE_OUT>();
      // Hoist numel() out of the loop and use a matching index type
      // (the original `auto i = 0` was a signed int compared against
      // the wider numel type every iteration).
      const auto numel = in.numel();
      for (decltype(in.numel()) i = 0; i < numel; ++i) {
        CTYPE_OUT val = static_cast<CTYPE_OUT>(data_in[i]);
        if (std::isnan(val)) {
          data_out[0] = val;
          break; // NaN dominates; no later element can change the result.
        }
        if (val < data_out[0]) {
          data_out[0] = val;
        }
      }
    });
  });

  return out;
}
142+
97143
} // namespace native
98144
} // namespace executor
99145
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,11 @@
552552
- arg_meta: null
553553
kernel_name: torch::executor::max_out
554554

555+
- op: max.unary_out
556+
kernels:
557+
- arg_meta: null
558+
kernel_name: torch::executor::max_unary_out
559+
555560
- op: maximum.out
556561
kernels:
557562
- arg_meta: null
@@ -572,6 +577,11 @@
572577
- arg_meta: null
573578
kernel_name: torch::executor::min_out
574579

580+
- op: min.unary_out
581+
kernels:
582+
- arg_meta: null
583+
kernel_name: torch::executor::min_unary_out
584+
575585
- op: minimum.out
576586
kernels:
577587
- arg_meta: null

kernels/test/op_max_test.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,64 @@ void OpMaxOutTest::test_max_out_dtype<ScalarType::Bool>() {
222222
// clang-format on
223223
}
224224

225+
// Tests for the max.unary_out op: reduce an entire tensor to a 0-D
// tensor holding its maximum element.
class OpMaxUnaryOutTest : public OperatorTest {
 protected:
  // Dispatches to the generated out-variant binding for max.unary_out.
  Tensor& op_max_unary_out(const Tensor& self, Tensor& out) {
    return torch::executor::aten::max_outf(context_, self, out);
  }

  // Max of a small 2x3 tensor of dtype IN_DTYPE, written to a Float
  // 0-D output (also exercises the input->output dtype cast).
  template <ScalarType IN_DTYPE>
  void test_max_unary_out_dtype() {
    TensorFactory<IN_DTYPE> tf_in;
    TensorFactory<ScalarType::Float> tf_out;
    Tensor input = tf_in.make({2, 3}, {0, 1, 2, 4, 4, 2});
    Tensor out = tf_out.zeros({});
    Tensor expected = tf_out.make({}, {4});
    op_max_unary_out(input, out);
    EXPECT_TENSOR_CLOSE(out, expected);
  }

  // Empty integer input: the result is the dtype's lowest value — the
  // identity element the kernel seeds the accumulator with.
  template <typename CTYPE, ScalarType IN_DTYPE>
  void test_max_unary_out_empty_integer() {
    TensorFactory<IN_DTYPE> tf_in;
    Tensor input = tf_in.make({2, 0}, {});
    Tensor out = tf_in.zeros({});
    Tensor expected = tf_in.make({}, {std::numeric_limits<CTYPE>::lowest()});
    op_max_unary_out(input, out);
    EXPECT_TENSOR_CLOSE(out, expected);
  }

  // Empty floating-point input: the result is -infinity.
  template <typename CTYPE, ScalarType IN_DTYPE>
  void test_max_unary_out_empty_floating() {
    TensorFactory<IN_DTYPE> tf_in;
    Tensor input = tf_in.make({2, 0}, {});
    Tensor out = tf_in.zeros({});
    Tensor expected = tf_in.make({}, {-INFINITY});
    op_max_unary_out(input, out);
    EXPECT_TENSOR_CLOSE(out, expected);
  }
};

// Run the basic-max case for every real/Half/BFloat16 input dtype.
TEST_F(OpMaxUnaryOutTest, AllRealHBF16InputFloatOutputPasses) {
#define TEST_ENTRY(ctype, dtype) test_max_unary_out_dtype<ScalarType::dtype>();
  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

// Run the empty-input case for every integer dtype.
TEST_F(OpMaxUnaryOutTest, EmptyIntegerInput) {
#define TEST_ENTRY(ctype, dtype) \
  test_max_unary_out_empty_integer<ctype, ScalarType::dtype>();
  ET_FORALL_INT_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

// Run the empty-input case for every floating-point dtype.
TEST_F(OpMaxUnaryOutTest, EmptyFloatingInput) {
#define TEST_ENTRY(ctype, dtype) \
  test_max_unary_out_empty_floating<ctype, ScalarType::dtype>();
  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}
282+
225283
TEST_F(OpMaxOutTest, MismatchedDimensionsDies) {
226284
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
227285
GTEST_SKIP() << "ATen kernel test fails";

kernels/test/op_min_test.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,64 @@ EXPECT_TENSOR_EQ(min_indices, tf_long.make(
218218
// clang-format on
219219
}
220220

221+
// Tests for the min.unary_out op: reduce an entire tensor to a 0-D
// tensor holding its minimum element.
class OpMinUnaryOutTest : public OperatorTest {
 protected:
  // Dispatches to the generated out-variant binding for min.unary_out.
  Tensor& op_min_unary_out(const Tensor& self, Tensor& out) {
    return torch::executor::aten::min_outf(context_, self, out);
  }

  // Min of a small 2x3 tensor of dtype IN_DTYPE, written to a Float
  // 0-D output (also exercises the input->output dtype cast).
  template <ScalarType IN_DTYPE>
  void test_min_unary_out_dtype() {
    TensorFactory<IN_DTYPE> tf_in;
    TensorFactory<ScalarType::Float> tf_out;
    Tensor input = tf_in.make({2, 3}, {7, 1, 3, 4, 4, 2});
    Tensor out = tf_out.zeros({});
    Tensor expected = tf_out.make({}, {1});
    op_min_unary_out(input, out);
    EXPECT_TENSOR_CLOSE(out, expected);
  }

  // Empty integer input: the result is the dtype's max value — the
  // identity element the kernel seeds the accumulator with.
  template <typename CTYPE, ScalarType IN_DTYPE>
  void test_min_unary_out_empty_integer() {
    TensorFactory<IN_DTYPE> tf_in;
    Tensor input = tf_in.make({2, 0}, {});
    Tensor out = tf_in.zeros({});
    Tensor expected = tf_in.make({}, {std::numeric_limits<CTYPE>::max()});
    op_min_unary_out(input, out);
    EXPECT_TENSOR_CLOSE(out, expected);
  }

  // Empty floating-point input: the result is +infinity.
  template <typename CTYPE, ScalarType IN_DTYPE>
  void test_min_unary_out_empty_floating() {
    TensorFactory<IN_DTYPE> tf_in;
    Tensor input = tf_in.make({2, 0}, {});
    Tensor out = tf_in.zeros({});
    Tensor expected = tf_in.make({}, {INFINITY});
    op_min_unary_out(input, out);
    EXPECT_TENSOR_CLOSE(out, expected);
  }
};

// Run the basic-min case for every real/Half/BFloat16 input dtype.
TEST_F(OpMinUnaryOutTest, AllRealHBF16InputFloatOutputPasses) {
#define TEST_ENTRY(ctype, dtype) test_min_unary_out_dtype<ScalarType::dtype>();
  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

// Run the empty-input case for every integer dtype.
TEST_F(OpMinUnaryOutTest, EmptyIntegerInput) {
#define TEST_ENTRY(ctype, dtype) \
  test_min_unary_out_empty_integer<ctype, ScalarType::dtype>();
  ET_FORALL_INT_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

// Run the empty-input case for every floating-point dtype.
TEST_F(OpMinUnaryOutTest, EmptyFloatingInput) {
#define TEST_ENTRY(ctype, dtype) \
  test_min_unary_out_empty_floating<ctype, ScalarType::dtype>();
  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}
}
278+
221279
TEST_F(OpMinOutTest, MismatchedDimensionsDies) {
222280
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
223281
GTEST_SKIP() << "ATen kernel test fails";

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -785,9 +785,6 @@ ATEN_OPS = (
785785
op_target(
786786
name = "op_max",
787787
deps = [
788-
"//executorch/runtime/core/exec_aten/util:scalar_type_util",
789-
"//executorch/runtime/core/exec_aten/util:tensor_util",
790-
"//executorch/kernels/portable/cpu/util:index_util",
791788
"//executorch/kernels/portable/cpu/util:reduce_util",
792789
],
793790
),
@@ -819,9 +816,6 @@ ATEN_OPS = (
819816
op_target(
820817
name = "op_min",
821818
deps = [
822-
"//executorch/runtime/core/exec_aten/util:scalar_type_util",
823-
"//executorch/runtime/core/exec_aten/util:tensor_util",
824-
"//executorch/kernels/portable/cpu/util:index_util",
825819
"//executorch/kernels/portable/cpu/util:reduce_util",
826820
],
827821
),

0 commit comments

Comments
 (0)