Skip to content

Commit 682c636

Browse files
authored
link mean dim kernels
Differential Revision: D68845587 Pull Request resolved: #8053
1 parent a972e73 commit 682c636

File tree

1 file changed

+27
-24
lines changed

1 file changed

+27
-24
lines changed

backends/cadence/fusion_g3/operators/op_mean.cpp

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ int prepare_data(
6060
return num_axis_dims;
6161
}
6262

63-
Tensor& mean_dim_out(
63+
Tensor& mean_out(
6464
KernelRuntimeContext& ctx,
6565
const Tensor& in,
6666
optional<ArrayRef<int64_t>> dim_list,
@@ -169,29 +169,32 @@ Tensor& mean_dim_out(
169169
InvalidArgument,
170170
out);
171171

172-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
173-
ET_SWITCH_FLOATH_TYPES(
174-
out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
175-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
176-
const size_t num =
177-
torch::executor::get_reduced_dim_product(in, dim_list);
178-
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
179-
CTYPE_OUT sum = 0;
180-
if (in.numel() > 0) {
181-
sum = torch::executor::
182-
map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
183-
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
184-
[](CTYPE_OUT outv, CTYPE_OUT acc) {
185-
return acc + outv;
186-
},
187-
in,
188-
dim_list,
189-
out_ix);
190-
}
191-
out_data[out_ix] = sum / static_cast<float>(num);
192-
}
193-
});
194-
});
172+
ET_SWITCH_REALHBBF16_TYPES(
173+
in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
174+
ET_SWITCH_FLOATHBF16_TYPES(
175+
out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
176+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
177+
const size_t num =
178+
torch::executor::get_reduced_dim_product(in, dim_list);
179+
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
180+
CTYPE_OUT sum = 0;
181+
if (in.numel() > 0) {
182+
sum = torch::executor::
183+
map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
184+
[](CTYPE_IN v) {
185+
return static_cast<CTYPE_OUT>(v);
186+
},
187+
[](CTYPE_OUT outv, CTYPE_OUT acc) {
188+
return acc + outv;
189+
},
190+
in,
191+
dim_list,
192+
out_ix);
193+
}
194+
out_data[out_ix] = sum / static_cast<float>(num);
195+
}
196+
});
197+
});
195198
}
196199

197200
return out;

0 commit comments

Comments
 (0)