Skip to content

Commit fed29b4

Browse files
mcremon-meta authored and facebook-github-bot committed
Add scalar cases for binary ops (add, mul, sub, div) on HiFi (#9411)
Summary: As titled. Currently those cases will go to the unoptimized broadcast call, which is extremely inefficient. A simple loop will do much better, and can be further optimized later if needed. Example of gains: mul op goes from 40M to 123k on the 27M ASR encoder. Differential Revision: D71495734
1 parent 4ecfc62 commit fed29b4

File tree

4 files changed

+63
-11
lines changed

4 files changed

+63
-11
lines changed

backends/cadence/hifi/operators/op_add.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,18 @@ Tensor& add_out(
138138
if ((out_type != ScalarType::Float) || (alpha_val != 1.0))
139139
optimized = 0;
140140

141-
if ((a_dim == 0) || (b_dim == 0))
142-
optimized = 0;
141+
bool float_types = (a_type == ScalarType::Float) && (b_type == ScalarType::Float);
142+
143+
if ((a_dim == 0) && float_types) {
144+
for (int i = 0; i < max_dim; i++)
145+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[0] + b.const_data_ptr<float>()[i];
146+
return out;
147+
}
148+
if ((b_dim == 0) && float_types) {
149+
for (int i = 0; i < max_dim; i++)
150+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[i] + b.const_data_ptr<float>()[0];
151+
return out;
152+
}
143153

144154
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
145155
optimized = 0;

backends/cadence/hifi/operators/op_div.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,19 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
8686
if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
8787
optimized = 0;
8888

89-
if ((a_dim == 0) || (b_dim == 0))
90-
optimized = 0;
89+
bool float_types = (a_type == ScalarType::Float) && (b_type == ScalarType::Float);
90+
91+
if ((a_dim == 0) && float_types) {
92+
for (int i = 0; i < max_dim; i++)
93+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[0] / b.const_data_ptr<float>()[i];
94+
return out;
95+
}
96+
if ((b_dim == 0) && float_types) {
97+
for (int i = 0; i < max_dim; i++)
98+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[i] / b.const_data_ptr<float>()[0];
99+
return out;
100+
}
101+
91102

92103
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
93104
optimized = 0;
@@ -201,8 +212,18 @@ Tensor& div_out_mode(
201212
if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
202213
optimized = 0;
203214

204-
if ((a_dim == 0) || (b_dim == 0))
205-
optimized = 0;
215+
bool float_types = (a_type == ScalarType::Float) && (b_type == ScalarType::Float);
216+
217+
if ((a_dim == 0) && float_types) {
218+
for (int i = 0; i < max_dim; i++)
219+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[0] / b.const_data_ptr<float>()[i];
220+
return out;
221+
}
222+
if ((b_dim == 0) && float_types) {
223+
for (int i = 0; i < max_dim; i++)
224+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[i] / b.const_data_ptr<float>()[0];
225+
return out;
226+
}
206227

207228
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
208229
optimized = 0;

backends/cadence/hifi/operators/op_mul.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,20 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
104104
int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
105105
max_dim = out.dim() > max_dim ? out.dim() : max_dim;
106106

107-
if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
108-
optimized = 0;
107+
bool float_types = (a_type == ScalarType::Float) && (b_type == ScalarType::Float);
108+
109+
if ((a_dim == 0) && float_types) {
110+
for (int i = 0; i < max_dim; i++)
111+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[0] * b.const_data_ptr<float>()[i];
112+
return out;
113+
}
114+
if ((b_dim == 0) && float_types) {
115+
for (int i = 0; i < max_dim; i++)
116+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[i] * b.const_data_ptr<float>()[0];
117+
return out;
118+
}
109119

110-
if ((a_dim == 0) || (b_dim == 0))
120+
if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
111121
optimized = 0;
112122

113123
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))

backends/cadence/hifi/operators/op_sub.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,19 @@ Tensor& sub_out(
133133
if ((out_type != ScalarType::Float) || (alpha_val != 1.0))
134134
optimized = 0;
135135

136-
if ((a_dim == 0) || (b_dim == 0))
137-
optimized = 0;
136+
bool float_types = (a_type == ScalarType::Float) && (b_type == ScalarType::Float);
137+
138+
if ((a_dim == 0) && float_types) {
139+
for (int i = 0; i < max_dim; i++)
140+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[0] - b.const_data_ptr<float>()[i];
141+
return out;
142+
}
143+
if ((b_dim == 0) && float_types) {
144+
for (int i = 0; i < max_dim; i++)
145+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[i] - b.const_data_ptr<float>()[0];
146+
return out;
147+
}
148+
138149

139150
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
140151
optimized = 0;

0 commit comments

Comments
 (0)