Skip to content

Commit 229b027

Browse files
committed
[executorch] Optimized 2D-by-1D broadcasting in optimized op_mul
Pull Request resolved: #4808. Detect that we are doing an elementwise multiplication for a 2D tensor and a 1D tensor, and dispatch to a vectorized kernel for this case. ghstack-source-id: 239747531 @exported-using-ghexport Differential Revision: [D61560826](https://our.internmc.facebook.com/intern/diff/D61560826/)
1 parent fcf7aba commit 229b027

File tree

3 files changed

+139
-35
lines changed

3 files changed

+139
-35
lines changed

kernels/optimized/cpu/op_mul.cpp

Lines changed: 86 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -22,45 +22,74 @@ using ScalarType = exec_aten::ScalarType;
2222

2323
namespace {
2424

25+
// NOTE: we bake ArrayRef iterators being pointers into the return
26+
// type here because we assume that iterators are portable across
27+
// ArrayRef copies.
28+
const Tensor::SizesType* arrayref_begin_ignoring_leading_1s(
29+
ArrayRef<Tensor::SizesType> arr) {
30+
return std::find_if(
31+
arr.begin(), arr.end(), [](Tensor::SizesType x) { return x != 1; });
32+
}
33+
2534
bool sizes_match_ignoring_leading_1s(
2635
ArrayRef<Tensor::SizesType> lhs,
2736
ArrayRef<Tensor::SizesType> rhs) {
28-
auto lhs_begin = lhs.begin();
37+
auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs);
2938
auto lhs_end = lhs.end();
30-
while (lhs_begin != lhs_end && *lhs_begin == 1) {
31-
++lhs_begin;
32-
}
3339

34-
auto rhs_begin = rhs.begin();
40+
auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs);
3541
auto rhs_end = rhs.end();
36-
while (rhs_begin != rhs_end && *rhs_begin == 1) {
37-
++rhs_begin;
38-
}
3942

4043
return ((lhs_end - lhs_begin) == (rhs_end - rhs_begin)) &&
4144
std::equal(lhs_begin, lhs_end, rhs_begin);
4245
}
4346

4447
// Move to generic util as this is applicable to all binary ops
45-
bool can_use_optimized_path(
46-
const Tensor& a,
47-
const Tensor& b,
48-
const Tensor& out) {
48+
enum class ElementwiseOptimizedPath {
49+
kNone,
50+
kTreatAs1d,
51+
kBroadcast2dBy1d,
52+
kBroadcast2dBy1dReverseArguments,
53+
};
54+
55+
ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
56+
const Tensor& lhs,
57+
const Tensor& rhs) {
58+
auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
59+
auto lhs_end = lhs.sizes().end();
60+
61+
auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
62+
auto rhs_end = rhs.sizes().end();
63+
64+
const auto lhs_size = lhs_end - lhs_begin;
65+
const auto rhs_size = rhs_end - rhs_begin;
66+
if (lhs_size == 2 && rhs_size == 1 && lhs_begin[1] == rhs_begin[0]) {
67+
return ElementwiseOptimizedPath::kBroadcast2dBy1d;
68+
}
69+
70+
if (lhs_size == 1 && rhs_size == 2 && rhs_begin[1] == lhs_begin[0]) {
71+
return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
72+
}
73+
74+
return ElementwiseOptimizedPath::kNone;
75+
}
76+
77+
ElementwiseOptimizedPath
78+
select_optimized_path(const Tensor& a, const Tensor& b, const Tensor& out) {
4979
ScalarType a_type = a.scalar_type();
5080
ScalarType b_type = b.scalar_type();
5181
ScalarType out_type = out.scalar_type();
5282

53-
bool can_use_optimized_path = true;
54-
can_use_optimized_path =
55-
can_use_optimized_path && ((a_type == b_type) && (a_type == out_type));
56-
can_use_optimized_path = can_use_optimized_path &&
57-
(a_type != ScalarType::Half && b_type != ScalarType::Half);
58-
can_use_optimized_path = can_use_optimized_path &&
59-
(a.sizes().equals(b.sizes()) ||
60-
(a.numel() == b.numel() &&
61-
(a.numel() == out.numel() ||
62-
sizes_match_ignoring_leading_1s(a.sizes(), b.sizes()))));
63-
return can_use_optimized_path;
83+
if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half) {
84+
return ElementwiseOptimizedPath::kNone;
85+
}
86+
if (a.sizes().equals(b.sizes()) ||
87+
(a.numel() == b.numel() &&
88+
(a.numel() == out.numel() ||
89+
sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
90+
return ElementwiseOptimizedPath::kTreatAs1d;
91+
}
92+
return select_broadcast_2d_by_1d_optimized_path(a, b);
6493
}
6594

6695
template <
@@ -147,7 +176,8 @@ Tensor& opt_mul_out(
147176
return opt_mul_out(ctx, b, a, out);
148177
}
149178

150-
if (can_use_optimized_path(a, b, out)) {
179+
auto selected_optimized_path = select_optimized_path(a, b, out);
180+
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
151181
// Resize for dynamic shape
152182
auto error = resize_tensor(out, a.sizes());
153183
ET_KERNEL_CHECK_MSG(
@@ -166,6 +196,38 @@ Tensor& opt_mul_out(
166196
b.const_data_ptr<CTYPE>(),
167197
out.numel());
168198
});
199+
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
200+
const Tensor* lhs;
201+
const Tensor* rhs;
202+
if (selected_optimized_path ==
203+
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
204+
lhs = &b;
205+
rhs = &a;
206+
} else {
207+
// Catch failure to update logic when adding new broadcasting possibility.
208+
ET_DCHECK(
209+
selected_optimized_path ==
210+
ElementwiseOptimizedPath::kBroadcast2dBy1d);
211+
lhs = &a;
212+
rhs = &b;
213+
}
214+
auto error = resize_tensor(out, lhs->sizes());
215+
ET_KERNEL_CHECK_MSG(
216+
ctx,
217+
error == Error::Ok,
218+
InvalidArgument,
219+
out,
220+
"Failed to resize output tensor.");
221+
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
222+
using Vec = executorch::vec::Vectorized<CTYPE>;
223+
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
224+
[](Vec x, Vec y) { return x * y; },
225+
out.mutable_data_ptr<CTYPE>(),
226+
lhs->const_data_ptr<CTYPE>(),
227+
rhs->const_data_ptr<CTYPE>(),
228+
lhs->sizes()[lhs->dim() - 2],
229+
lhs->sizes()[lhs->dim() - 1]);
230+
});
169231
} else {
170232
ScalarType common_type =
171233
promoteTypes(a_type, b_type, /*half_to_float*/ true);

kernels/optimized/vec/functional_base.h

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -325,5 +325,40 @@ inline void map4(
325325
}
326326
}
327327

328+
329+
// Map vec_fun across input_data and input_data2, where input_data is
330+
// a two-dimensional array of size (size, size2), input_data2 is a
331+
// one-dimensional array of size size2, and input_data2 is broadcast
332+
// to be of size (size, size2).
333+
template <typename scalar_t, typename Op>
334+
inline void broadcasting_map_2d_by_1d(
335+
const Op& vec_fun,
336+
scalar_t* output_data,
337+
const scalar_t* input_data,
338+
const scalar_t* input_data2,
339+
int64_t size,
340+
int64_t size2) {
341+
using Vec = vec::Vectorized<scalar_t>;
342+
for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
343+
const scalar_t* input_data_row = input_data + outer_idx * size2;
344+
scalar_t* output_data_row = output_data + outer_idx * size2;
345+
int64_t inner_idx = 0;
346+
for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
347+
Vec data_vec = Vec::loadu(input_data_row + inner_idx);
348+
Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
349+
Vec output_vec = vec_fun(data_vec, data_vec2);
350+
output_vec.store(output_data_row + inner_idx);
351+
}
352+
if (size2 - inner_idx > 0) {
353+
Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
354+
Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
355+
Vec output_vec = vec_fun(data_vec, data_vec2);
356+
output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
357+
}
358+
}
359+
}
360+
361+
362+
328363
} // namespace vec
329364
} // namespace executorch

kernels/test/op_mul_test.cpp

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -181,19 +181,19 @@ TEST_F(OpMulOutTest, OptimizedPathIgnoresLeading1Dimensions) {
181181
}
182182

183183
// Mismatched shape tests.
184-
TEST_F(OpMulOutTest, MismatchedInputShapesDies) {
184+
TEST_F(OpMulOutTest, MismatchedNonBroadcastableInputShapesDies) {
185185
if (SupportedFeatures::get()->is_aten) {
186186
GTEST_SKIP() << "ATen currently supports mismatched shapes";
187187
}
188188

189189
TensorFactory<ScalarType::Int> tf;
190190

191191
// Input tensors with different shapes.
192-
Tensor a = tf.ones(/*sizes=*/{1, 2});
192+
Tensor a = tf.ones(/*sizes=*/{4, 2});
193193
Tensor b = tf.ones(/*sizes=*/{2, 2});
194194

195195
// Output tensor; matches the shape of one of the inputs.
196-
Tensor out = tf.zeros(/*sizes=*/{4});
196+
Tensor out = tf.zeros(/*sizes=*/{8});
197197

198198
// Multiplying the two mismatched tensors should cause an assertion and kill
199199
// the test process.
@@ -204,16 +204,22 @@ TEST_F(OpMulOutTest, MismatchedInputShapesDies) {
204204
TEST_F(OpMulOutTest, BroadcastA2BTest) {
205205
TensorFactory<ScalarType::Int> tf_a;
206206

207-
// a and b of different shapes
208-
Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
209-
Tensor b = tf_a.make({2}, /*data=*/{2, 2});
207+
std::vector<std::vector<int32_t>> b_sizeses = {
208+
{2},
209+
{1, 2},
210+
};
211+
for (const auto& b_sizes : b_sizeses) {
212+
// a and b of different shapes
213+
Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
214+
Tensor b = tf_a.make(b_sizes, /*data=*/{2, 2});
210215

211-
// Destination for output of mul.
212-
Tensor out = tf_a.zeros({2, 2});
216+
// Destination for output of mul.
217+
Tensor out = tf_a.zeros({2, 2});
213218

214-
// Check that it matches the expected output.
215-
EXPECT_TENSOR_CLOSE(
216-
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
219+
// Check that it matches the expected output.
220+
EXPECT_TENSOR_CLOSE(
221+
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
222+
}
217223
}
218224

219225
// Broadcast tensor a's size to tensor b's size
@@ -262,6 +268,7 @@ TEST_F(OpMulOutTest, ScalarInputBroadcastTest) {
262268

263269
// Check that it matches the expected output.
264270
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
271+
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
265272
}
266273

267274
TEST_F(OpMulOutTest, MismatchedOutputShapesDies) {

0 commit comments

Comments (0)