@@ -14,6 +14,7 @@
 #include <cmath>
 #include <type_traits>
 
+#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 // `_log_softmax_out` Applies the Log_Softmax function to an n-dimensional input
@@ -32,6 +33,11 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
   const IN_T* __restrict__ input_data_base = input.data_ptr<IN_T>();
   OUT_T* __restrict__ output_data_base = out.data_ptr<OUT_T>();
 
+  if (input.dim() == 0) {
+    output_data_base[0] = 0;
+    return;
+  }
+
   int64_t dim_size = input.size(dim);
 
   int64_t outer_size = 1;
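For reference: log_softmax computes y_i = x_i - log(Σ_j exp(x_j)) along `dim`, so for a 0-dim (single-element) tensor the result is always x - log(exp(x)) = 0, which is exactly what the early return above writes. Below is a minimal stand-alone sketch of the numerically stable three-pass computation a kernel like `log_softmax_kernel` performs over the `outer_size`/`dim_size` decomposition visible in the surrounding context lines; all names are illustrative, and this is not the PR's actual loop body.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative sketch only: numerically stable log_softmax along the middle
// axis of an (outer, dim, inner) view of a contiguous float tensor. The names
// are hypothetical, but the real kernel iterates with the same decomposition.
void log_softmax_reference(
    const float* in,
    float* out,
    int64_t outer_size,
    int64_t dim_size,
    int64_t inner_size) {
  for (int64_t o = 0; o < outer_size; ++o) {
    for (int64_t i = 0; i < inner_size; ++i) {
      const float* x = in + o * dim_size * inner_size + i;
      float* y = out + o * dim_size * inner_size + i;
      // Pass 1: max along dim, subtracted later for numerical stability.
      float max_val = x[0];
      for (int64_t d = 1; d < dim_size; ++d) {
        max_val = std::max(max_val, x[d * inner_size]);
      }
      // Pass 2: log(sum(exp(x - max))).
      float sum = 0.0f;
      for (int64_t d = 0; d < dim_size; ++d) {
        sum += std::exp(x[d * inner_size] - max_val);
      }
      const float log_sum = std::log(sum);
      // Pass 3: y = x - max - log_sum. With dim_size == 1 this is exactly 0,
      // which is why the 0-dim early return above can simply write 0.
      for (int64_t d = 0; d < dim_size; ++d) {
        y[d * inner_size] = x[d * inner_size] - max_val - log_sum;
      }
    }
  }
}
```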
@@ -116,39 +122,6 @@ void log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
 }
 } // namespace
 
-void opt_log_soft_max_check_preconditions(
-    const Tensor& self,
-    int64_t dim,
-    bool half_to_float,
-    Tensor& out) {
-  // Ensure half_to_float is not true
-  ET_CHECK_MSG(
-      !half_to_float,
-      "softmax with half to float conversion is not supported on CPU");
-  // Ensure self has value
-  ET_CHECK_MSG(self.numel() > 0, "self.numel() %zd <= 0", self.numel());
-  // Ensure dim is valid
-  ET_CHECK_MSG(
-      dim >= 0 && dim < self.dim(),
-      "dim %" PRId64 " >= 0 && dim %" PRId64 " < self.dim() %zd",
-      dim,
-      dim,
-      self.dim());
-  // Ensure self and out have the same shape
-  ET_CHECK_SAME_SHAPE2(self, out);
-  // Ensure self and out are float
-  auto out_scalar_type = out.scalar_type();
-  ET_CHECK_MSG(
-      out_scalar_type == ScalarType::Float,
-      "out.scalar_type() %" PRId8 " is not Float",
-      static_cast<int8_t>(out_scalar_type));
-  auto input_scalar_type = self.scalar_type();
-  ET_CHECK_MSG(
-      input_scalar_type == ScalarType::Float,
-      "self.scalar_type() %" PRId8 " is not Float",
-      static_cast<int8_t>(input_scalar_type));
-}
-
 // _log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out)
 // -> Tensor(a!)
 Tensor& opt_log_softmax_out(
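The deletion above moves these hand-rolled argument checks into the shared portable helper `check_log_softmax_args()` from `activation_ops_util.h` (included at the top of this diff), so the optimized and portable kernels validate inputs in one place. Note that in the new flow the check runs before the negative-`dim` normalization, so the shared helper must tolerate negative dims itself. As a rough sketch only, mirroring the deleted checks rather than quoting the real helper's body:

```cpp
// Sketch only: approximates the checks the deleted local function performed.
// The real check_log_softmax_args() lives in
// executorch/kernels/portable/cpu/util/activation_ops_util.h and may differ.
void check_log_softmax_args_sketch(
    const Tensor& in, int64_t dim, bool half_to_float, const Tensor& out) {
  ET_CHECK_MSG(
      !half_to_float,
      "softmax with half to float conversion is not supported on CPU");
  // Accept negative dims here, since normalization now happens at the call
  // site after this check; treat a 0-dim tensor as having one dimension.
  const auto d = in.dim() == 0 ? 1 : in.dim();
  ET_CHECK_MSG(dim >= -d && dim < d, "dim out of range");
  ET_CHECK_SAME_SHAPE2(in, out);
  ET_CHECK_SAME_DTYPE2(in, out);
}
```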
@@ -158,16 +131,14 @@ Tensor& opt_log_softmax_out(
     bool half_to_float,
     Tensor& out) {
   (void)context;
-  dim = dim < 0 ? dim + self.dim() : dim;
-  Tensor::SizesType expected_output_size[16];
-  for (size_t i = 0; i < out.dim(); ++i) {
-    expected_output_size[i] = self.size(i);
-  }
-  auto error = resize_tensor(
-      out, {expected_output_size, static_cast<size_t>(out.dim())});
-  // TODO: Construct error message with requested output sizes.
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
-  opt_log_soft_max_check_preconditions(self, dim, half_to_float, out);
+
+  check_log_softmax_args(self, dim, half_to_float, out);
+
+  ET_KERNEL_CHECK(
+      context,
+      resize_tensor(out, self.sizes()) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  dim = dim < 0 ? dim + nonzero_dim(self) : dim;
+
   auto out_scalar_type = out.scalar_type();
   switch (out_scalar_type) {
     // TODO: support Double as well
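One more detail worth noting in the hunk above: the rewritten normalization uses `nonzero_dim(self)` rather than `self.dim()`, and the failed-resize path now reports `InvalidArgument` through the runtime context via `ET_KERNEL_CHECK` and returns `out`, instead of aborting through `ET_CHECK_MSG`. Assuming `nonzero_dim()` treats a 0-dim tensor as having one dimension (a plausible reading, not confirmed by this diff), the change keeps `dim = -1` meaningful for scalars, where `-1 + self.dim()` would have produced `-1`. A minimal illustration under that assumption:

```cpp
#include <cstdint>

// Illustration only: assumed semantics of the normalization above, with a
// stand-in for nonzero_dim() that maps a 0-dim tensor to 1 dimension.
inline int64_t normalize_dim(int64_t dim, int64_t ndim) {
  const int64_t effective_ndim = ndim == 0 ? 1 : ndim; // assumed nonzero_dim()
  return dim < 0 ? dim + effective_ndim : dim;
}

// normalize_dim(-1, 3) == 2   // last axis of a 3-d tensor
// normalize_dim(-1, 0) == 0   // scalar tensor: -1 + 1
// normalize_dim(1, 3)  == 1   // non-negative dims pass through unchanged
```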