@@ -60,7 +60,7 @@ int prepare_data(
   return num_axis_dims;
 }
 
-Tensor& mean_dim_out(
+Tensor& mean_out(
     KernelRuntimeContext& ctx,
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
@@ -169,29 +169,32 @@ Tensor& mean_dim_out(
       InvalidArgument,
       out);
 
-  ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
-    ET_SWITCH_FLOATH_TYPES(
-        out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
-          CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-          const size_t num =
-              torch::executor::get_reduced_dim_product(in, dim_list);
-          for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-            CTYPE_OUT sum = 0;
-            if (in.numel() > 0) {
-              sum = torch::executor::
-                  map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
-                      [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-                      [](CTYPE_OUT outv, CTYPE_OUT acc) {
-                        return acc + outv;
-                      },
-                      in,
-                      dim_list,
-                      out_ix);
-            }
-            out_data[out_ix] = sum / static_cast<float>(num);
-          }
-        });
-  });
+  ET_SWITCH_REALHBBF16_TYPES(
+      in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
+        ET_SWITCH_FLOATHBF16_TYPES(
+            out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
+              CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+              const size_t num =
+                  torch::executor::get_reduced_dim_product(in, dim_list);
+              for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
+                CTYPE_OUT sum = 0;
+                if (in.numel() > 0) {
+                  sum = torch::executor::
+                      map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+                          [](CTYPE_IN v) {
+                            return static_cast<CTYPE_OUT>(v);
+                          },
+                          [](CTYPE_OUT outv, CTYPE_OUT acc) {
+                            return acc + outv;
+                          },
+                          in,
+                          dim_list,
+                          out_ix);
+                }
+                out_data[out_ix] = sum / static_cast<float>(num);
+              }
+            });
+      });
 
   return out;
 }
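
The first hunk renames the entry point from mean_dim_out to mean_out. The second hunk widens the dtype dispatch: ET_SWITCH_REALHB_TYPES / ET_SWITCH_FLOATH_TYPES become ET_SWITCH_REALHBBF16_TYPES / ET_SWITCH_FLOATHBF16_TYPES, which, per ExecuTorch's macro naming, adds BFloat16 to the input and output types mean.out accepts. The arithmetic is unchanged: each output element is the sum over dim_list, accumulated in the output type, divided by the number of reduced elements. A minimal standalone sketch of that reduction, assuming a hypothetical mean_over_last_dim helper and flat arrays in place of ExecuTorch tensors:

#include <cstddef>
#include <iostream>

// Hypothetical helper (not part of the ExecuTorch API): mean over the last
// dimension of a row-major [out_numel, num] buffer, accumulating in
// CTYPE_OUT just as the kernel's map_reduce_over_dim_list loop does.
template <typename CTYPE_IN, typename CTYPE_OUT>
void mean_over_last_dim(
    const CTYPE_IN* in,
    CTYPE_OUT* out,
    std::size_t out_numel,
    std::size_t num) {
  for (std::size_t out_ix = 0; out_ix < out_numel; ++out_ix) {
    CTYPE_OUT sum = 0;
    for (std::size_t i = 0; i < num; ++i) {
      sum += static_cast<CTYPE_OUT>(in[out_ix * num + i]);
    }
    out[out_ix] = sum / static_cast<CTYPE_OUT>(num);
  }
}

int main() {
  const int in[6] = {1, 2, 3, 4, 5, 6}; // a [2, 3] input, flattened
  float out[2];
  mean_over_last_dim(in, out, /*out_numel=*/2, /*num=*/3);
  std::cout << out[0] << " " << out[1] << "\n"; // prints: 2 5
}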