
Commit fb0a6e1

mcremon-meta authored and facebook-github-bot committed
Add scalar cases for add and mul on HiFi
Summary: As titled. Currently these cases fall through to the unoptimized broadcast call, which is extremely inefficient. A simple loop does much better, and can be further optimized later if needed.

Example of gains: the mul op goes from 40M to 123k on the 27M ASR encoder.

Differential Revision: D71495734
1 parent e0235f0 commit fb0a6e1
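To make the idea in the summary concrete, here is a minimal standalone sketch of the scalar fast path. This is an illustration only, not the ExecuTorch kernel API: the plain std::vector<float> buffers and the scalar_mul_sketch helper are hypothetical stand-ins for the Tensor-based code in the diffs below.

#include <cstdio>
#include <vector>

// Sketch of the "a is a 0-dim scalar" case: read the single value once and
// apply it across b with a flat loop, instead of going through the generic
// broadcast machinery.
static void scalar_mul_sketch(
    const std::vector<float>& a, // holds exactly one element
    const std::vector<float>& b,
    std::vector<float>& out) {
  const float scalar = a[0];
  for (size_t i = 0; i < b.size(); i++)
    out[i] = scalar * b[i];
}

int main() {
  std::vector<float> a = {2.0f};             // scalar operand
  std::vector<float> b = {1.0f, 2.0f, 3.0f}; // tensor operand
  std::vector<float> out(b.size());
  scalar_mul_sketch(a, b, out);
  for (float v : out)
    std::printf("%f\n", v); // prints 2.0, 4.0, 6.0
  return 0;
}

In op_add.cpp and op_mul.cpp below, the same loop is written directly against the Tensor data pointers and returns early, so only non-scalar shapes continue on to the existing broadcast and kNnlibMaxDim checks.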


2 files changed: +25 additions, −5 deletions


backends/cadence/hifi/operators/op_add.cpp

Lines changed: 12 additions & 2 deletions
@@ -138,8 +138,18 @@ Tensor& add_out(
   if ((out_type != ScalarType::Float) || (alpha_val != 1.0))
     optimized = 0;
 
-  if ((a_dim == 0) || (b_dim == 0))
-    optimized = 0;
+  bool float_types = (a_type == ScalarType::Float) && (b_type == ScalarType::Float);
+
+  if ((a_dim == 0) && float_types) {
+    for (int i = 0; i < max_dim; i++)
+      out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[0] + b.const_data_ptr<float>()[i];
+    return out;
+  }
+  if ((b_dim == 0) && float_types) {
+    for (int i = 0; i < max_dim; i++)
+      out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[i] + b.const_data_ptr<float>()[0];
+    return out;
+  }
 
   if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
     optimized = 0;

backends/cadence/hifi/operators/op_mul.cpp

Lines changed: 13 additions & 3 deletions
@@ -104,10 +104,20 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
   int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
   max_dim = out.dim() > max_dim ? out.dim() : max_dim;
 
-  if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
-    optimized = 0;
+  bool float_types = (a_type == ScalarType::Float) && (b_type == ScalarType::Float);
+
+  if ((a_dim == 0) && float_types) {
+    for (int i = 0; i < max_dim; i++)
+      out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[0] * b.const_data_ptr<float>()[i];
+    return out;
+  }
+  if ((b_dim == 0) && float_types) {
+    for (int i = 0; i < max_dim; i++)
+      out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[i] * b.const_data_ptr<float>()[0];
+    return out;
+  }
 
-  if ((a_dim == 0) || (b_dim == 0))
+  if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
     optimized = 0;
 
   if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
