@@ -44,23 +44,24 @@ Tensor& mean_dim_out(
       InvalidArgument,
       out);
 
-  ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
-    ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
-      CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-      const size_t num = get_reduced_dim_product(in, dim_list);
-      for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-        CTYPE_OUT sum = 0;
-        if (in.numel() > 0) {
-          sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
-              [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-              [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-              in,
-              dim_list,
-              out_ix);
-        }
-        out_data[out_ix] = sum / static_cast<float>(num);
-      }
-    });
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
+    ET_SWITCH_FLOATHBF16_TYPES(
+        out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
+          CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+          const size_t num = get_reduced_dim_product(in, dim_list);
+          for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
+            CTYPE_OUT sum = 0;
+            if (in.numel() > 0) {
+              sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+                  [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
+                  [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                  in,
+                  dim_list,
+                  out_ix);
+            }
+            out_data[out_ix] = sum / static_cast<float>(num);
+          }
+        });
   });
 
   return out;
0 commit comments