Commit e89b744

manuelcandales authored and facebook-github-bot committed
Cleanup log_softmax & softmax (#791)
Summary:
Pull Request resolved: #791

Refactors helper functions for checking arguments, to be reused by optimized `log_softmax`

ghstack-source-id: 203560636
exported-using-ghexport

Reviewed By: cbilgin

Differential Revision: D50130021

fbshipit-source-id: f964002a195563aef7d1b573b453da5423585647
1 parent 20df1a2 commit e89b744

5 files changed: +46 -71 lines changed

kernels/portable/cpu/op_log_softmax.cpp

Lines changed: 5 additions & 36 deletions

@@ -8,6 +8,7 @@
 
 #include <cmath>
 
+#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/kernels/portable/cpu/util/functional_util.h>
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -18,37 +19,6 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 
-namespace {
-
-void check_preconditions(
-    const Tensor& in,
-    int64_t dim,
-    bool half_to_float,
-    Tensor& out) {
-  // Ensure half_to_float is not true
-  ET_CHECK_MSG(
-      !half_to_float,
-      "log_softmax with half to float conversion is not supported on CPU");
-  // Check both in and out are of the same dtype
-  ET_CHECK_SAME_DTYPE2(in, out);
-  // Check both in and out have the same number of dimensions
-  ET_CHECK_MSG(
-      in.dim() == out.dim(),
-      "in.dim() %zd!= out.dim() %zd",
-      in.dim(),
-      out.dim());
-  // Ensure dim is valid
-  if (in.dim() == 0) {
-    ET_CHECK_MSG(dim == 0 || dim == -1, "dim must be 0 or -1 for 0-D tensor");
-  } else {
-    ET_CHECK_VALID_DIM(dim, in.dim());
-  }
-  ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(in);
-  ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(out);
-}
-
-} // namespace
-
 Tensor& log_softmax_out(
     RuntimeContext& ctx,
     const Tensor& in,
@@ -57,14 +27,13 @@ Tensor& log_softmax_out(
     Tensor& out) {
   (void)ctx;
 
-  check_preconditions(in, dim, half_to_float, out);
+  check_log_softmax_args(in, dim, half_to_float, out);
 
-  Error err = resize_tensor(out, in.sizes());
-  ET_CHECK_MSG(
-      err == Error::Ok, "Failed to resize out tensor in log_softmax_out");
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
 
   // Adjust for negative dim
-  dim = dim < 0 ? dim + in.dim() : dim;
+  dim = dim < 0 ? dim + nonzero_dim(in) : dim;
 
   ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "log_softmax", CTYPE, [&]() {
     const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
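Note the switch from `in.dim()` to `nonzero_dim(in)` when normalizing a negative `dim`: for a 0-D tensor, `in.dim()` is 0, so `dim == -1` would remain negative after the adjustment. A hypothetical sketch of what `nonzero_dim` (pulled in via `reduce_util.h`) plausibly does, assuming it treats a 0-D tensor as having one dimension:

    // Hypothetical sketch, not the verbatim helper; see reduce_util.h
    // for the actual definition.
    inline int64_t nonzero_dim(const Tensor& in) {
      // Treat a 0-D tensor as 1-D so that dim == -1 normalizes to 0.
      return in.dim() == 0 ? 1 : in.dim();
    }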

kernels/portable/cpu/op_softmax.cpp

Lines changed: 5 additions & 35 deletions

@@ -8,6 +8,7 @@
 
 #include <cmath>
 
+#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/kernels/portable/cpu/util/functional_util.h>
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -18,37 +19,6 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 
-namespace {
-
-void check_preconditions(
-    const Tensor& in,
-    int64_t dim,
-    bool half_to_float,
-    Tensor& out) {
-  // Ensure half_to_float is not true
-  ET_CHECK_MSG(
-      !half_to_float,
-      "softmax with half to float conversion is not supported on CPU");
-  // Check both in and out are of the same dtype
-  ET_CHECK_SAME_DTYPE2(in, out);
-  // Check both in and out have the same number of dimensions
-  ET_CHECK_MSG(
-      in.dim() == out.dim(),
-      "in.dim() %zd!= out.dim() %zd",
-      in.dim(),
-      out.dim());
-  // Ensure dim is valid
-  if (in.dim() == 0) {
-    ET_CHECK(dim == 0 || dim == -1);
-  } else {
-    ET_CHECK_VALID_DIM(dim, in.dim());
-  }
-  ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(in);
-  ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(out);
-}
-
-} // namespace
-
 Tensor& softmax_out(
     RuntimeContext& ctx,
     const Tensor& in,
@@ -57,13 +27,13 @@ Tensor& softmax_out(
     Tensor& out) {
   (void)ctx;
 
-  check_preconditions(in, dim, half_to_float, out);
+  check_softmax_args(in, dim, half_to_float, out);
 
-  Error err = resize_tensor(out, in.sizes());
-  ET_CHECK_MSG(err == Error::Ok, "Failed to resize out tensor in softmax_out");
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
 
   // Adjust for negative dim
-  dim = dim < 0 ? dim + in.dim() : dim;
+  dim = dim < 0 ? dim + nonzero_dim(in) : dim;
 
   ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "softmax", CTYPE, [&]() {
     const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
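The lambda dispatched by `ET_SWITCH_FLOAT_TYPES` applies the standard numerically stable softmax along each slice of `dim`. A self-contained sketch of that math over one contiguous slice (the actual kernel walks slices with the `reduce_util` helpers rather than a single flat loop):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // Numerically stable softmax over one contiguous slice of length n.
    template <typename CTYPE>
    void softmax_slice(const CTYPE* in, CTYPE* out, size_t n) {
      // Subtract the slice max before exponentiating to avoid overflow.
      const CTYPE max_val = *std::max_element(in, in + n);
      CTYPE sum = 0;
      for (size_t i = 0; i < n; ++i) {
        out[i] = std::exp(in[i] - max_val);
        sum += out[i];
      }
      for (size_t i = 0; i < n; ++i) {
        out[i] /= sum;
      }
    }

For log_softmax, the final loop would instead compute out[i] = (in[i] - max_val) - std::log(sum).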

kernels/portable/cpu/targets.bzl

Lines changed: 2 additions & 0 deletions

@@ -434,6 +434,7 @@ _ATEN_OPS = (
         name = "op_log_softmax",
         deps = [
             ":vec_ops",
+            "//executorch/kernels/portable/cpu/util:activation_ops_util",
             "//executorch/kernels/portable/cpu/util:functional_util",
             "//executorch/kernels/portable/cpu/util:reduce_util",
         ],
@@ -693,6 +694,7 @@ _ATEN_OPS = (
         name = "op_softmax",
         deps = [
             ":vec_ops",
+            "//executorch/kernels/portable/cpu/util:activation_ops_util",
             "//executorch/kernels/portable/cpu/util:functional_util",
             "//executorch/kernels/portable/cpu/util:reduce_util",
         ],

kernels/portable/cpu/util/activation_ops_util.cpp

Lines changed: 22 additions & 0 deletions

@@ -23,5 +23,27 @@ bool check_gelu_args(const Tensor& in, string_view approximate, Tensor& out) {
   return true;
 }
 
+bool check_log_softmax_args(
+    const Tensor& in,
+    int64_t dim,
+    bool half_to_float,
+    Tensor& out) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      !half_to_float, "half to float conversion is not supported on CPU");
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
+  return true;
+}
+
+bool check_softmax_args(
+    const Tensor& in,
+    int64_t dim,
+    bool half_to_float,
+    Tensor& out) {
+  return check_log_softmax_args(in, dim, half_to_float, out);
+}
+
 } // namespace executor
 } // namespace torch
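Unlike the removed `ET_CHECK_*` calls, which aborted on failure, these helpers log and return `false`. That lets a caller surface a failed argument check as a recoverable kernel error. A hedged sketch of that pattern, reusing the `ET_KERNEL_CHECK(ctx, cond, error, retval)` shape visible in this commit's resize check (the op files in this commit still call the helpers bare):

    // Sketch only; assumes the caller has ctx, in, dim, half_to_float, out
    // in scope, as in log_softmax_out above.
    ET_KERNEL_CHECK(
        ctx,
        check_log_softmax_args(in, dim, half_to_float, out),
        InvalidArgument,
        out);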

kernels/portable/cpu/util/activation_ops_util.h

Lines changed: 12 additions & 0 deletions

@@ -15,5 +15,17 @@ namespace executor {
 
 bool check_gelu_args(const Tensor& in, string_view approximate, Tensor& out);
 
+bool check_log_softmax_args(
+    const Tensor& in,
+    int64_t dim,
+    bool half_to_float,
+    Tensor& out);
+
+bool check_softmax_args(
+    const Tensor& in,
+    int64_t dim,
+    bool half_to_float,
+    Tensor& out);
+
 } // namespace executor
 } // namespace torch
