
Commit 469b5ca

manuelcandales authored and facebook-github-bot committed
Fix optimized log_softmax (#792)
Summary: Pull Request resolved: #792

ghstack-source-id: 203560708
exported-using-ghexport

Reviewed By: cbilgin

Differential Revision: D50130020

fbshipit-source-id: 9b442346aceafb42deccb18e6f99d93aa2306cb2
1 parent e89b744 commit 469b5ca

2 files changed: 18 additions & 44 deletions


kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 14 additions & 43 deletions
@@ -14,6 +14,7 @@
 #include <cmath>
 #include <type_traits>
 
+#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 // `_log_softmax_out` Applies the Log_Softmax function to an n-dimensional input
@@ -32,6 +33,11 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
   const IN_T* __restrict__ input_data_base = input.data_ptr<IN_T>();
   OUT_T* __restrict__ output_data_base = out.data_ptr<OUT_T>();
 
+  if (input.dim() == 0) {
+    output_data_base[0] = 0;
+    return;
+  }
+
   int64_t dim_size = input.size(dim);
 
   int64_t outer_size = 1;
@@ -116,39 +122,6 @@ void log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
 }
 } // namespace
 
-void opt_log_soft_max_check_preconditions(
-    const Tensor& self,
-    int64_t dim,
-    bool half_to_float,
-    Tensor& out) {
-  // Ensure half_to_float is not true
-  ET_CHECK_MSG(
-      !half_to_float,
-      "softmax with half to float conversion is not supported on CPU");
-  // Ensure self has value
-  ET_CHECK_MSG(self.numel() > 0, "self.numel() %zd <= 0", self.numel());
-  // Ensure dim is valid
-  ET_CHECK_MSG(
-      dim >= 0 && dim < self.dim(),
-      "dim %" PRId64 " >= 0 && dim %" PRId64 " < self.dim() %zd",
-      dim,
-      dim,
-      self.dim());
-  // Ensure self and out have the same shape
-  ET_CHECK_SAME_SHAPE2(self, out);
-  // Ensure self and out are float
-  auto out_scalar_type = out.scalar_type();
-  ET_CHECK_MSG(
-      out_scalar_type == ScalarType::Float,
-      "out.scalar_type() %" PRId8 " is not Float",
-      static_cast<int8_t>(out_scalar_type));
-  auto input_scalar_type = self.scalar_type();
-  ET_CHECK_MSG(
-      input_scalar_type == ScalarType::Float,
-      "self.scalar_type() %" PRId8 " is not Float",
-      static_cast<int8_t>(input_scalar_type));
-}
-
 // _log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out)
 // -> Tensor(a!)
 Tensor& opt_log_softmax_out(
@@ -158,16 +131,14 @@ Tensor& opt_log_softmax_out(
     bool half_to_float,
     Tensor& out) {
   (void)context;
-  dim = dim < 0 ? dim + self.dim() : dim;
-  Tensor::SizesType expected_output_size[16];
-  for (size_t i = 0; i < out.dim(); ++i) {
-    expected_output_size[i] = self.size(i);
-  }
-  auto error = resize_tensor(
-      out, {expected_output_size, static_cast<size_t>(out.dim())});
-  // TODO: Construct error message with requested output sizes.
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
-  opt_log_soft_max_check_preconditions(self, dim, half_to_float, out);
+
+  check_log_softmax_args(self, dim, half_to_float, out);
+
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, self.sizes()) == Error::Ok, InvalidArgument, out);
+
+  dim = dim < 0 ? dim + nonzero_dim(self) : dim;
+
   auto out_scalar_type = out.scalar_type();
   switch (out_scalar_type) {
     // TODO: support Double as well
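For orientation: log-softmax is conventionally computed in the numerically stable form out_i = x_i - max(x) - log(sum_j exp(x_j - max(x))). That identity is what makes the new zero-dim early return correct: a 0-dim tensor holds exactly one element, so the reduction collapses to x - x - log(exp(0)) = 0, which is precisely the value written by `output_data_base[0] = 0`. (Relatedly, the switch from `self.dim()` to `nonzero_dim(self)` presumably keeps negative-`dim` normalization working for 0-dim inputs, where `dim()` is 0.) Below is a minimal standalone sketch of that computation; `log_softmax_1d` is an illustrative name, not part of the ExecuTorch sources, and it assumes a non-empty input.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative reference for the math in op_log_softmax.cpp:
// numerically stable log-softmax over a 1-D, non-empty range.
std::vector<float> log_softmax_1d(const std::vector<float>& x) {
  // Shift by the max so std::exp cannot overflow for large inputs.
  const float max_val = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (const float v : x) {
    sum += std::exp(v - max_val);
  }
  const float log_sum = std::log(sum);
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    // With a single element this is x - x - log(exp(0)) == 0,
    // matching the zero-dim early return added above.
    out[i] = x[i] - max_val - log_sum;
  }
  return out;
}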

kernels/optimized/cpu/targets.bzl

Lines changed: 4 additions & 1 deletion
@@ -41,8 +41,11 @@ _OPTIMIZED_ATEN_OPS = (
     op_target(
         name = "op_log_softmax",
         deps = select({
-            "DEFAULT": [],
+            "DEFAULT": [
+                "//executorch/kernels/portable/cpu/util:activation_ops_util",
+            ],
             "ovr_config//runtime:fbcode-arm64": [
+                "//executorch/kernels/portable/cpu/util:activation_ops_util",
                 "fbsource//third-party/sleef:sleef_arm",
             ],
         }),
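Worth noting: the build change mirrors the new `#include` in op_log_softmax.cpp; both branches of the `select` now depend on `//executorch/kernels/portable/cpu/util:activation_ops_util`, so `check_log_softmax_args` resolves in default builds as well as in the arm64 configuration that additionally pulls in Sleef.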
