Skip to content

Commit 2a55cb4

Browse files
committed
[executorch] Add vectorized scalar path for single-element Tensor passed to optimized mul
We are currently doing slow broadcasting for this case. After this diff, we should get nice vectorization. Differential Revision: [D61560825](https://our.internmc.facebook.com/intern/diff/D61560825/) ghstack-source-id: 238994617 Pull Request resolved: #4807
1 parent fee9377 commit 2a55cb4

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

kernels/optimized/cpu/op_mul.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,15 @@ template <
106106
typename CTYPE_OUT>
107107
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
108108
: public ReportCanCastBug {};
109+
110+
Scalar tensor_to_scalar(RuntimeContext& ctx, const Tensor& t) {
111+
ET_DCHECK(t.numel() == 1);
112+
Scalar result;
113+
ET_SWITCH_REALHB_TYPES(t.scalar_type(), ctx, "mul.out", CTYPE, [&]() {
114+
result = Scalar(*t.const_data_ptr<CTYPE>());
115+
});
116+
return result;
117+
}
109118
} // namespace
110119

111120
Tensor& opt_mul_out(
@@ -119,6 +128,35 @@ Tensor& opt_mul_out(
119128
ScalarType b_type = b.scalar_type();
120129
ScalarType out_type = out.scalar_type();
121130

131+
if (b.numel() == 1) {
132+
if (a_type == b_type && a_type == out_type &&
133+
a_type != ScalarType::Half) {
134+
auto error = resize_tensor(out, a.sizes());
135+
ET_KERNEL_CHECK_MSG(
136+
ctx,
137+
error == Error::Ok,
138+
InvalidArgument,
139+
out,
140+
"Failed to resize output tensor.");
141+
ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.out", CTYPE, [&]() {
142+
ET_SWITCH_REALB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
143+
CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
144+
CTYPE b_casted = static_cast<CTYPE>(b_val);
145+
146+
using Vec = executorch::vec::Vectorized<CTYPE>;
147+
executorch::vec::map<CTYPE>(
148+
[b_casted](Vec x) { return x * Vec(b_casted); },
149+
out.mutable_data_ptr<CTYPE>(),
150+
a.const_data_ptr<CTYPE>(),
151+
out.numel());
152+
});
153+
});
154+
return out;
155+
}
156+
} else if (a.numel() == 1) {
157+
return opt_mul_out(ctx, b, a, out);
158+
}
159+
122160
if (can_use_optimized_path(a, b, out)) {
123161
// Resize for dynamic shape
124162
auto error = resize_tensor(out, a.sizes());

0 commit comments

Comments
 (0)