[ET][Optimized] Fix optimized log_softmax #792

Closed

57 changes: 14 additions & 43 deletions kernels/optimized/cpu/op_log_softmax.cpp
@@ -14,6 +14,7 @@
 #include <cmath>
 #include <type_traits>
 
+#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 // `_log_softmax_out` Applies the Log_Softmax function to an n-dimensional input
@@ -32,6 +32,11 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
   const IN_T* __restrict__ input_data_base = input.data_ptr<IN_T>();
   OUT_T* __restrict__ output_data_base = out.data_ptr<OUT_T>();
 
+  if (input.dim() == 0) {
+    output_data_base[0] = 0;
+    return;
+  }
+
   int64_t dim_size = input.size(dim);
 
   int64_t outer_size = 1;
@@ -116,39 +122,6 @@ void log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
 }
 } // namespace
 
-void opt_log_soft_max_check_preconditions(
-    const Tensor& self,
-    int64_t dim,
-    bool half_to_float,
-    Tensor& out) {
-  // Ensure half_to_float is not true
-  ET_CHECK_MSG(
-      !half_to_float,
-      "softmax with half to float conversion is not supported on CPU");
-  // Ensure self has value
-  ET_CHECK_MSG(self.numel() > 0, "self.numel() %zd <= 0", self.numel());
-  // Ensure dim is valid
-  ET_CHECK_MSG(
-      dim >= 0 && dim < self.dim(),
-      "dim %" PRId64 " >= 0 && dim %" PRId64 " < self.dim() %zd",
-      dim,
-      dim,
-      self.dim());
-  // Ensure self and out have the same shape
-  ET_CHECK_SAME_SHAPE2(self, out);
-  // Ensure self and out are float
-  auto out_scalar_type = out.scalar_type();
-  ET_CHECK_MSG(
-      out_scalar_type == ScalarType::Float,
-      "out.scalar_type() %" PRId8 " is not Float",
-      static_cast<int8_t>(out_scalar_type));
-  auto input_scalar_type = self.scalar_type();
-  ET_CHECK_MSG(
-      input_scalar_type == ScalarType::Float,
-      "self.scalar_type() %" PRId8 " is not Float",
-      static_cast<int8_t>(input_scalar_type));
-}
-
 // _log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out)
 // -> Tensor(a!)
 Tensor& opt_log_softmax_out(
@@ -158,16 +131,14 @@ Tensor& opt_log_softmax_out(
     bool half_to_float,
     Tensor& out) {
   (void)context;
-  dim = dim < 0 ? dim + self.dim() : dim;
-  Tensor::SizesType expected_output_size[16];
-  for (size_t i = 0; i < out.dim(); ++i) {
-    expected_output_size[i] = self.size(i);
-  }
-  auto error = resize_tensor(
-      out, {expected_output_size, static_cast<size_t>(out.dim())});
-  // TODO: Construct error message with requested output sizes.
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
-  opt_log_soft_max_check_preconditions(self, dim, half_to_float, out);
+
+  check_log_softmax_args(self, dim, half_to_float, out);
+
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, self.sizes()) == Error::Ok, InvalidArgument, out);
+
+  dim = dim < 0 ? dim + nonzero_dim(self) : dim;
+
   auto out_scalar_type = out.scalar_type();
   switch (out_scalar_type) {
     // TODO: support Double as well
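For context, the kernel computes a numerically stable log-softmax along dim: log_softmax(x_i) = (x_i - max(x)) - log(sum_j exp(x_j - max(x))). For a 0-dimensional (scalar) input the softmax of the single element is 1, so its log-softmax is 0, which is what the new input.dim() == 0 early return writes. Below is a minimal plain-C++ sketch of that per-row computation; it is illustrative only (the helper name log_softmax_reference is made up here) and is not the optimized ExecuTorch kernel.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log-softmax over a single row of values. Subtracting the
// row maximum before exponentiating avoids overflow in std::exp.
std::vector<float> log_softmax_reference(const std::vector<float>& x) {
  // A scalar (or single-element) input has softmax 1, hence log-softmax 0,
  // matching the early return added to log_softmax_kernel above.
  if (x.size() <= 1) {
    return std::vector<float>(x.size(), 0.0f);
  }
  const float max_val = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (float v : x) {
    sum += std::exp(v - max_val);
  }
  const float log_sum = std::log(sum);
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = (x[i] - max_val) - log_sum;
  }
  return out;
}

For example, log_softmax_reference({1.0f, 2.0f, 3.0f}) returns approximately {-2.4076, -1.4076, -0.4076}.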
5 changes: 4 additions & 1 deletion kernels/optimized/cpu/targets.bzl
@@ -41,8 +41,11 @@ _OPTIMIZED_ATEN_OPS = (
     op_target(
         name = "op_log_softmax",
         deps = select({
-            "DEFAULT": [],
+            "DEFAULT": [
+                "//executorch/kernels/portable/cpu/util:activation_ops_util",
+            ],
             "ovr_config//runtime:fbcode-arm64": [
+                "//executorch/kernels/portable/cpu/util:activation_ops_util",
                 "fbsource//third-party/sleef:sleef_arm",
             ],
         }),