
Commit bf7874f

[Executorch] Handle broadcast semantics for last dim

Pull Request resolved: #6521. This diff adds support for handling the element-wise mul op when the broadcast is across the last dim.

ghstack-source-id: 251185496
@exported-using-ghexport

Differential Revision: [D64156863](https://our.internmc.facebook.com/intern/diff/D64156863/)

1 parent 77fe041
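Concretely, "broadcast across the last dim" means the second operand has size 1 in its final dimension, so its single value per row is reused across that dimension. A minimal plain-C++ sketch of the semantics (not the ExecuTorch kernel; it hard-codes the {4, 3} x {4, 1} shapes used in the tests below):

#include <array>
#include <cstdio>

int main() {
  // a has shape [4, 3]; b has shape [4, 1] and broadcasts along the last dim:
  // out[i][j] = a[i][j] * b[i][0].
  std::array<std::array<float, 3>, 4> a = {
      {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}};
  std::array<float, 4> b = {2, 3, 4, 5};
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 3; ++j) {
      // rows: 2 4 6 | 12 15 18 | 28 32 36 | 50 55 60
      std::printf("%g ", a[i][j] * b[i]);
    }
    std::printf("\n");
  }
}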

File tree: 5 files changed, +241 -61 lines

kernels/optimized/cpu/binary_ops.h

Lines changed: 25 additions & 7 deletions

@@ -43,17 +43,27 @@ enum class ElementwiseOptimizedPath {
   kBroadcast2dBy1dReverseArguments,
   kBroadcastNdByNd,
   kBroadcastNdByNdReverseArguments,
+  kBroadcastLastDim,
+  kBroadcastLastDimReverseArguments,
 };

 namespace internal {

-// Find the single broadcast dimension if it exists.
-// This path aims to handle broadcast of the following form
-// A = [a1, a2, ..., 1, ..., an]
-// B = [b1, b2, ..., bm, ..., bn]
-// OR
-// A = [a1, a2, ..., am, ..., an]
-// B = [b1, b2, ..., 1, ..., bn]
+/*
+Given two tensors, this function returns the broadcast dim if it exists.
+Returns 0 if no broadcast dim is found.
+Otherwise a negative index indicates the broadcast dim,
+e.g. if size = [a, b, c, 1, e, f] then the broadcast dim is -3.
+
+This path aims to handle broadcasts of the following form:
+A = [a1, a2, ..., 1, ..., an]
+B = [b1, b2, ..., bm, ..., bn]
+OR
+A = [a1, a2, ..., am, ..., an]
+B = [b1, b2, ..., 1, ..., bn]
+Note that this way of determining the broadcast dim also works
+when the broadcast dim is the last dim.
+*/
 int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
   auto lhs_end = lhs.sizes().end();

@@ -125,6 +135,14 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     } else {
       return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
     }
+  } else if (broadcast_dim == -1) {
+    if (std::count_if(lhs_begin, lhs_end, [](Tensor::SizesType x) {
+          return x == 1;
+        }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastLastDim;
+    }
   }
   return ElementwiseOptimizedPath::kNone;
 }
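For intuition about the negative-index convention, the sketch below (not the ExecuTorch implementation; it assumes equal ranks, works on plain size vectors, and ignores the leading-1s trimming that arrayref_begin_ignoring_leading_1s performs) finds a single broadcast dim and reports it as an offset back from the last dimension:

#include <cstdint>
#include <iostream>
#include <vector>

// Returns the broadcast dim as a negative offset from the end, or 0 if the
// shapes have no single broadcastable mismatch.
int32_t broadcast_dim_of(
    const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) {
    return 0; // sketch only handles equal ranks
  }
  int32_t found = 0;
  for (size_t i = 0; i < lhs.size(); ++i) {
    if (lhs[i] == rhs[i]) {
      continue;
    }
    if (lhs[i] != 1 && rhs[i] != 1) {
      return 0; // mismatch that is not a broadcast
    }
    if (found != 0) {
      return 0; // more than one broadcast dim
    }
    found = static_cast<int32_t>(i) - static_cast<int32_t>(lhs.size());
  }
  return found;
}

int main() {
  // Last-dim broadcast, the new case handled by this commit: -1.
  std::cout << broadcast_dim_of({2, 2, 3, 1}, {2, 2, 3, 5}) << "\n"; // -1
  // The [a, b, c, 1, e, f] example from the comment above: -3.
  std::cout << broadcast_dim_of({4, 5, 6, 1, 8, 9}, {4, 5, 6, 7, 8, 9}) << "\n"; // -3
}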

kernels/optimized/cpu/op_mul.cpp

Lines changed: 111 additions & 54 deletions

@@ -11,6 +11,7 @@
 #include <executorch/kernels/optimized/vec/vec.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>

@@ -66,6 +67,115 @@ template <
     typename CTYPE_OUT>
 struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
     : public ReportCanCastBug {};
+
+Tensor& handle_last_dim_broadcast(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if (selected_optimized_path ==
+      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
+  const auto broadcast_size = out.size(out.dim() - 1);
+  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
+        [](Vec x, Vec y) { return x * y; },
+        out.mutable_data_ptr<CTYPE>(),
+        lhs->const_data_ptr<CTYPE>(),
+        rhs->const_data_ptr<CTYPE>(),
+        outer_size,
+        broadcast_size);
+  });
+  return out;
+}
+
+Tensor& handle_broadcast_mul(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDim) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
+    return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
+  }
+
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    // Catch failure to update logic when adding new broadcasting possibility.
+    ET_DCHECK(
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd));
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  int64_t outer_size = 1;
+  int64_t broadcast_size;
+  int64_t inner_size;
+  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+    auto normalized_tensor_size_lhs =
+        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+    outer_size = normalized_tensor_size_lhs[0];
+    broadcast_size = normalized_tensor_size_lhs[1];
+    inner_size = normalized_tensor_size_lhs[2];
+  } else {
+    broadcast_size = lhs->sizes()[lhs->dim() - 2];
+    inner_size = lhs->sizes()[lhs->dim() - 1];
+  }
+  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
+        [](Vec x, Vec y) { return x * y; },
+        out.mutable_data_ptr<CTYPE>(),
+        lhs->const_data_ptr<CTYPE>(),
+        rhs->const_data_ptr<CTYPE>(),
+        outer_size,
+        broadcast_size,
+        inner_size);
+  });
+  return out;
+}
 } // namespace

 Tensor& opt_mul_out(

@@ -128,60 +238,7 @@ Tensor& opt_mul_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      lhs = &b;
-      rhs = &a;
-    } else {
-      // Catch failure to update logic when adding new broadcasting possibility.
-      ET_DCHECK(
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcastNdByNd));
-      lhs = &a;
-      rhs = &b;
-    }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    int64_t outer_size = 1;
-    int64_t broadcast_size;
-    int64_t inner_size;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-      auto normalized_tensor_size_lhs =
-          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-      outer_size = normalized_tensor_size_lhs[0];
-      broadcast_size = normalized_tensor_size_lhs[1];
-      inner_size = normalized_tensor_size_lhs[2];
-    } else {
-      broadcast_size = lhs->sizes()[lhs->dim() - 2];
-      inner_size = lhs->sizes()[lhs->dim() - 1];
-    }
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
-          [](Vec x, Vec y) { return x * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          lhs->const_data_ptr<CTYPE>(),
-          rhs->const_data_ptr<CTYPE>(),
-          outer_size,
-          broadcast_size,
-          inner_size);
-    });
+    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
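The last-dim path reduces everything to two numbers: outer_size, the product of all leading dims (via getLeadingDims(out, out.dim() - 1)), and broadcast_size, the size of the last dim. For example, a = {2, 2, 3} times b = {2, 2, 1} is processed as 4 rows of 3 elements each. A small sketch of that arithmetic, using a hypothetical leading_dims helper rather than the ExecuTorch API:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for getLeadingDims(out, out.dim() - 1): the product
// of every dim except the last.
int64_t leading_dims(const std::vector<int64_t>& sizes) {
  int64_t prod = 1;
  for (size_t i = 0; i + 1 < sizes.size(); ++i) {
    prod *= sizes[i];
  }
  return prod;
}

int main() {
  std::vector<int64_t> out_sizes = {2, 2, 3}; // a:[2,2,3] * b:[2,2,1]
  int64_t outer_size = leading_dims(out_sizes); // 4 rows...
  int64_t broadcast_size = out_sizes.back();    // ...of 3 elements each
  std::cout << outer_size << " x " << broadcast_size << "\n"; // prints "4 x 3"
}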

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ _OPTIMIZED_ATEN_OPS = (
             ":binary_ops",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
         ],
     ),
     op_target(

kernels/optimized/vec/functional_base.h

Lines changed: 38 additions & 0 deletions

@@ -380,5 +380,43 @@ inline void broadcasting_map_2d_by_1d(
   broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }

+/*
+The following function implements a broadcasting binary operation on two
+tensors, where the lhs tensor is treated as shape [outer_size, broadcast_size]
+and the rhs tensor is treated as shape [outer_size, 1].
+Any two N-dimensional tensors with
+lhs size = [lhs0, lhs1, ..., lhsN-1] and rhs size = [rhs0, rhs1, ..., 1]
+can be mapped to this formula by viewing the two tensors as
+lhs size = [lhs0 * lhs1 * ... * lhsN-2, lhsN-1]
+rhs size = [rhs0 * rhs1 * ... * rhsN-2, 1]
+*/
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_broadcast_last_dim(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = broadcast_size;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    int64_t inner_idx = 0;
+    // Splat this row's single rhs value across a full vector lane.
+    Vec data_vec2 = Vec(rhs[outer_idx]);
+    for (; inner_idx < broadcast_size - (broadcast_size % Vec::size());
+         inner_idx += Vec::size()) {
+      Vec data_vec = Vec::loadu(lhs_outer + inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx);
+    }
+    // Tail: the remaining broadcast_size % Vec::size() elements use the
+    // count-limited loadu/store overloads.
+    if (broadcast_size - inner_idx > 0) {
+      Vec data_vec = Vec::loadu(lhs_outer + inner_idx, broadcast_size - inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx, broadcast_size - inner_idx);
+    }
+  }
+}
+
 } // namespace vec
 } // namespace executorch
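As a reference for what the vectorized loop above computes, a scalar equivalent (a sketch, not part of the commit) is:

#include <cstdint>

// Scalar reference for broadcasting_map_broadcast_last_dim:
// out[i][j] = op(lhs[i][j], rhs[i][0]), with lhs viewed as
// [outer_size, broadcast_size] and rhs as [outer_size, 1].
template <typename scalar_t, typename Op>
void broadcast_last_dim_reference(
    const Op& op,
    scalar_t* out,
    const scalar_t* lhs,
    const scalar_t* rhs,
    int64_t outer_size,
    int64_t broadcast_size) {
  for (int64_t i = 0; i < outer_size; ++i) {
    for (int64_t j = 0; j < broadcast_size; ++j) {
      out[i * broadcast_size + j] = op(lhs[i * broadcast_size + j], rhs[i]);
    }
  }
}

With outer_size = 4 and broadcast_size = 3 this reproduces the {4, 3} x {4, 1} expected values in the tests below.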

kernels/test/op_mul_test.cpp

Lines changed: 66 additions & 0 deletions

@@ -220,6 +220,60 @@ class OpMulOutTest : public OperatorTest {
     EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
   }

+  template <ScalarType DTYPE>
+  void test_broadcast_last_dim() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5});
+
+    // Destination for output of mul.
+    Tensor out = tf_a.zeros({4, 3});
+    Tensor expected = tf_a.make(
+        {4, 3}, /*data=*/{2, 4, 6, 12, 15, 18, 28, 32, 36, 50, 55, 60});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+
+    a = tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5});
+
+    // Destination for output of mul.
+    out = tf_a.zeros({2, 2, 3});
+    expected = tf_a.make(
+        {2, 2, 3}, /*data=*/{2, 4, 6, 12, 15, 18, 28, 32, 36, 50, 55, 60});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+
+    a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    b = tf_a.make(
+        {2, 2, 3, 1},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+
+    // Destination for output of mul.
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,   2,   3,   4,   5,   12,  14,  16,  18,  20,  33,  36,
+                  39,  42,  45,  64,  68,  72,  76,  80,  105, 110, 115, 120,
+                  125, 156, 162, 168, 174, 180, 217, 224, 231, 238, 245, 288,
+                  296, 304, 312, 320, 369, 378, 387, 396, 405, 460, 470, 480,
+                  490, 500, 561, 572, 583, 594, 605, 672, 684, 696, 708, 720});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+  }
+
   template <ScalarType DTYPE>
   void test_broadcast_b2a() {
     TensorFactory<DTYPE> tf_a;

@@ -392,6 +446,18 @@ TEST_F(OpMulOutTest, BroadcastNDTest) {
   test_broadcast_4D<ScalarType::Float>();
   test_broadcast_4D<ScalarType::Half>();
   test_broadcast_4D<ScalarType::BFloat16>();
+
+  // Test broadcasting on the last dimension
+  test_broadcast_last_dim<ScalarType::Float>();
+  test_broadcast_last_dim<ScalarType::Half>();
+  test_broadcast_last_dim<ScalarType::BFloat16>();
+}
+
+TEST_F(OpMulOutTest, BroadcastLastDimTest) {
+  // Test broadcasting on the last dimension
+  test_broadcast_last_dim<ScalarType::Float>();
+  test_broadcast_last_dim<ScalarType::Half>();
+  test_broadcast_last_dim<ScalarType::BFloat16>();
 }

 // Broadcast tensor a and b's size to a new size c.
