Skip to content

Commit 87dd81a

Browse files
authored
[Cadence] Add scalar cases for binary ops (add, mul, sub, div) on HiFi
Differential Revision: D71495734 Pull Request resolved: #9411
1 parent 4903f0a commit 87dd81a

File tree

4 files changed

+76
-11
lines changed

4 files changed

+76
-11
lines changed

backends/cadence/hifi/operators/op_add.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,21 @@ Tensor& add_out(
138138
if ((out_type != ScalarType::Float) || (alpha_val != 1.0))
139139
optimized = 0;
140140

141-
if ((a_dim == 0) || (b_dim == 0))
142-
optimized = 0;
141+
bool float_types =
142+
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
143+
144+
if ((a_dim == 0) && float_types) {
145+
for (int i = 0; i < max_dim; i++)
146+
out.mutable_data_ptr<float>()[i] =
147+
a.const_data_ptr<float>()[0] + b.const_data_ptr<float>()[i];
148+
return out;
149+
}
150+
if ((b_dim == 0) && float_types) {
151+
for (int i = 0; i < max_dim; i++)
152+
out.mutable_data_ptr<float>()[i] =
153+
a.const_data_ptr<float>()[i] + b.const_data_ptr<float>()[0];
154+
return out;
155+
}
143156

144157
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
145158
optimized = 0;

backends/cadence/hifi/operators/op_div.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,21 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
8686
if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
8787
optimized = 0;
8888

89-
if ((a_dim == 0) || (b_dim == 0))
90-
optimized = 0;
89+
bool float_types =
90+
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
91+
92+
if ((a_dim == 0) && float_types) {
93+
for (int i = 0; i < max_dim; i++)
94+
out.mutable_data_ptr<float>()[i] =
95+
a.const_data_ptr<float>()[0] / b.const_data_ptr<float>()[i];
96+
return out;
97+
}
98+
if ((b_dim == 0) && float_types) {
99+
for (int i = 0; i < max_dim; i++)
100+
out.mutable_data_ptr<float>()[i] =
101+
a.const_data_ptr<float>()[i] / b.const_data_ptr<float>()[0];
102+
return out;
103+
}
91104

92105
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
93106
optimized = 0;
@@ -201,8 +214,21 @@ Tensor& div_out_mode(
201214
if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
202215
optimized = 0;
203216

204-
if ((a_dim == 0) || (b_dim == 0))
205-
optimized = 0;
217+
bool float_types =
218+
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
219+
220+
if ((a_dim == 0) && float_types) {
221+
for (int i = 0; i < max_dim; i++)
222+
out.mutable_data_ptr<float>()[i] =
223+
a.const_data_ptr<float>()[0] / b.const_data_ptr<float>()[i];
224+
return out;
225+
}
226+
if ((b_dim == 0) && float_types) {
227+
for (int i = 0; i < max_dim; i++)
228+
out.mutable_data_ptr<float>()[i] =
229+
a.const_data_ptr<float>()[i] / b.const_data_ptr<float>()[0];
230+
return out;
231+
}
206232

207233
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
208234
optimized = 0;

backends/cadence/hifi/operators/op_mul.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,23 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
104104
int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
105105
max_dim = out.dim() > max_dim ? out.dim() : max_dim;
106106

107-
if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
108-
optimized = 0;
107+
bool float_types =
108+
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
109+
110+
if ((a_dim == 0) && float_types) {
111+
for (int i = 0; i < max_dim; i++)
112+
out.mutable_data_ptr<float>()[i] =
113+
a.const_data_ptr<float>()[0] * b.const_data_ptr<float>()[i];
114+
return out;
115+
}
116+
if ((b_dim == 0) && float_types) {
117+
for (int i = 0; i < max_dim; i++)
118+
out.mutable_data_ptr<float>()[i] =
119+
a.const_data_ptr<float>()[i] * b.const_data_ptr<float>()[0];
120+
return out;
121+
}
109122

110-
if ((a_dim == 0) || (b_dim == 0))
123+
if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
111124
optimized = 0;
112125

113126
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))

backends/cadence/hifi/operators/op_sub.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,21 @@ Tensor& sub_out(
133133
if ((out_type != ScalarType::Float) || (alpha_val != 1.0))
134134
optimized = 0;
135135

136-
if ((a_dim == 0) || (b_dim == 0))
137-
optimized = 0;
136+
bool float_types =
137+
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
138+
139+
if ((a_dim == 0) && float_types) {
140+
for (int i = 0; i < max_dim; i++)
141+
out.mutable_data_ptr<float>()[i] =
142+
a.const_data_ptr<float>()[0] - b.const_data_ptr<float>()[i];
143+
return out;
144+
}
145+
if ((b_dim == 0) && float_types) {
146+
for (int i = 0; i < max_dim; i++)
147+
out.mutable_data_ptr<float>()[i] =
148+
a.const_data_ptr<float>()[i] - b.const_data_ptr<float>()[0];
149+
return out;
150+
}
138151

139152
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
140153
optimized = 0;

0 commit comments

Comments
 (0)