Commit 77fe041
[Executorch] mul broadcast update (#6589)

Pull Request resolved: #6520

Handle broadcast for >2-D tensors in the optimized library. For now, only a single broadcast dim other than the 0th and (N-1)-th dim is supported in the optimized path.

ghstack-source-id: 251021628
@exported-using-ghexport

Differential Revision: [D64156862](https://our.internmc.facebook.com/intern/diff/D64156862/)

---------

Co-authored-by: Kimish Patel <[email protected]>
1 parent 3a1538a commit 77fe041
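Concretely, per the path-selection logic in the diffs below: a multiply such as [2, 2, 3, 5] * [2, 1, 3, 5] (a single broadcast on an interior dim) now takes the optimized path, while a broadcast on dim 0 (e.g. [1, 3, 4, 5] * [2, 3, 4, 5]) or on the last dim still falls back to the generic kernel.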

4 files changed: +249 −33 lines changed

kernels/optimized/cpu/binary_ops.h

Lines changed: 86 additions & 2 deletions
@@ -41,10 +41,62 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
+  kBroadcastNdByNd,
+  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+
+// Find the single broadcast dimension, if it exists.
+// This path aims to handle broadcasts of the following form:
+// A = [a1, a2, ..., 1, ..., an]
+// B = [b1, b2, ..., bm, ..., bn]
+// OR
+// A = [a1, a2, ..., am, ..., an]
+// B = [b1, b2, ..., 1, ..., bn]
+int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+
+  // The following example is not handled at the moment:
+  // [1, 3, 4, 5]
+  // [2, 3, 4, 5]
+  if (lhs_size != rhs_size) {
+    return 0;
+  }
+
+  int32_t broadcast_dim = 0;
+  // Check
+  // 1. if any dim value is 1 (it constitutes a broadcast dim),
+  // 2. if more than one dim value is 1 (we cannot handle that), and
+  // 3. if the non-1 dim values are equal.
+  lhs_end--;
+  rhs_end--;
+  while (lhs_end != lhs_begin) {
+    if (*lhs_end == 1 || *rhs_end == 1) {
+      // If more than one broadcast dim is found, return 0.
+      if (broadcast_dim != 0) {
+        return 0;
+      }
+      // A negative index is used.
+      broadcast_dim = lhs_end - lhs.sizes().end();
+    } else if (*lhs_end != *rhs_end) {
+      // If non-1 dim values are not equal, return 0.
+      return 0;
+    }
+    lhs_end--;
+    rhs_end--;
+  }
+  return broadcast_dim;
+}
+
+inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -63,6 +115,17 @@ inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
+  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
+  // Right now we don't handle last-dim broadcast.
+  if (broadcast_dim < -1) {
+    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
+          return x == 1;
+        }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastNdByNd;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
+    }
+  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -85,7 +148,28 @@ ElementwiseOptimizedPath inline select_optimized_path(
       internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
+  return internal::select_broadcast_optimized_path(a, b);
+}
+
+std::array<int32_t, 3> inline get_normalized_tensor_size(
+    const Tensor& a,
+    const int32_t broadcast_dim) {
+  ET_CHECK_MSG(
+      a.dim() > broadcast_dim,
+      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
+      a.dim(),
+      broadcast_dim);
+  std::array<int32_t, 3> normalized_tensor_size;
+  normalized_tensor_size[0] = 1;
+  normalized_tensor_size[1] = a.size(broadcast_dim);
+  normalized_tensor_size[2] = 1;
+  for (size_t i = 0; i < broadcast_dim; i++) {
+    normalized_tensor_size[0] *= a.size(i);
+  }
+  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
+    normalized_tensor_size[2] *= a.size(i);
+  }
+  return normalized_tensor_size;
}
 
 } // namespace executor
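To make the negative-index convention concrete, here is a standalone scalar sketch (illustration only, not part of the commit; the names get_broadcast_dim_sketch and normalize_sketch are hypothetical) that mirrors get_broadcast_dim and get_normalized_tensor_size on plain size vectors:

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors get_broadcast_dim on plain size vectors (leading 1s assumed already
// stripped). Returns a negative index counted from the end of the shape, or 0
// if the shapes are not handled by this path. Dim 0 is never examined.
int32_t get_broadcast_dim_sketch(
    const std::vector<int64_t>& lhs,
    const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) {
    return 0;
  }
  const int64_t rank = static_cast<int64_t>(lhs.size());
  int32_t broadcast_dim = 0;
  for (int64_t i = rank - 1; i > 0; --i) {
    if (lhs[i] == 1 || rhs[i] == 1) {
      if (broadcast_dim != 0) {
        return 0; // more than one broadcast dim: not handled
      }
      broadcast_dim = static_cast<int32_t>(i - rank); // negative index
    } else if (lhs[i] != rhs[i]) {
      return 0; // mismatched non-1 dims: not handled
    }
  }
  return broadcast_dim;
}

// Mirrors get_normalized_tensor_size: collapse shape a around broadcast_dim
// into [outer, broadcast, inner].
std::array<int64_t, 3> normalize_sketch(
    const std::vector<int64_t>& a,
    int64_t broadcast_dim) {
  std::array<int64_t, 3> n = {1, a[broadcast_dim], 1};
  for (int64_t i = 0; i < broadcast_dim; i++) {
    n[0] *= a[i];
  }
  for (int64_t i = broadcast_dim + 1; i < static_cast<int64_t>(a.size()); i++) {
    n[2] *= a[i];
  }
  return n;
}

int main() {
  // Shapes from the 4D test further below: a = [2, 2, 3, 5], b = [2, 1, 3, 5].
  std::vector<int64_t> a = {2, 2, 3, 5};
  std::vector<int64_t> b = {2, 1, 3, 5};
  int32_t d = get_broadcast_dim_sketch(a, b); // -3, i.e. dim 1 from the front
  // As in opt_mul_out below: positive dim = rank + d = 4 + (-3) = 1, so a is
  // viewed as [outer, broadcast, inner] = [2, 2, 15] and b as [2, 1, 15].
  std::array<int64_t, 3> n =
      normalize_sketch(a, static_cast<int64_t>(a.size()) + d);
  std::printf(
      "broadcast_dim=%d normalized=[%lld, %lld, %lld]\n",
      d,
      static_cast<long long>(n[0]),
      static_cast<long long>(n[1]),
      static_cast<long long>(n[2]));
  return 0;
}

Note that select_broadcast_optimized_path additionally requires broadcast_dim < -1, so a broadcast on the last dim (index -1) is rejected even when get_broadcast_dim finds one.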

kernels/optimized/cpu/op_mul.cpp

Lines changed: 30 additions & 7 deletions
@@ -130,15 +130,19 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
       lhs = &b;
       rhs = &a;
     } else {
       // Catch failure to update logic when adding new broadcasting possibility.
       ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcastNdByNd));
       lhs = &a;
       rhs = &b;
     }
@@ -149,15 +153,34 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
+    int64_t outer_size = 1;
+    int64_t broadcast_size;
+    int64_t inner_size;
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+      auto normalized_tensor_size_lhs =
+          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+      outer_size = normalized_tensor_size_lhs[0];
+      broadcast_size = normalized_tensor_size_lhs[1];
+      inner_size = normalized_tensor_size_lhs[2];
+    } else {
+      broadcast_size = lhs->sizes()[lhs->dim() - 2];
+      inner_size = lhs->sizes()[lhs->dim() - 1];
+    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
           [](Vec x, Vec y) { return x * y; },
           out.mutable_data_ptr<CTYPE>(),
           lhs->const_data_ptr<CTYPE>(),
           rhs->const_data_ptr<CTYPE>(),
-          lhs->sizes()[lhs->dim() - 2],
-          lhs->sizes()[lhs->dim() - 1]);
+          outer_size,
+          broadcast_size,
+          inner_size);
     });
   } else {
     ScalarType common_type =
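For reference, this is what the dispatched kernel computes on the Nd-by-Nd path, written as a naive scalar loop (a sketch assuming contiguous, row-major data; broadcasting_mul_3d_reference is a hypothetical name, not an ExecuTorch API). lhs is viewed as [outer, broadcast, inner] and rhs as [outer, 1, inner]:

#include <cstdint>
#include <cstdio>

// out[o][b][i] = lhs[o][b][i] * rhs[o][0][i]: each rhs row is reused across
// the broadcast dimension, mirroring the loop structure of the vectorized
// kernel in functional_base.h below.
void broadcasting_mul_3d_reference(
    float* out,
    const float* lhs,
    const float* rhs,
    int64_t outer,
    int64_t broadcast,
    int64_t inner) {
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t b = 0; b < broadcast; ++b) {
      for (int64_t i = 0; i < inner; ++i) {
        out[(o * broadcast + b) * inner + i] =
            lhs[(o * broadcast + b) * inner + i] * rhs[o * inner + i];
      }
    }
  }
}

int main() {
  // Data from the 3D test below: a = [2, 2, 3], b = [2, 1, 3].
  const float a[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  const float b[6] = {2, 3, 4, 5, 6, 7};
  float out[12];
  broadcasting_mul_3d_reference(
      out, a, b, /*outer=*/2, /*broadcast=*/2, /*inner=*/3);
  // Prints 2 6 12 8 15 24 35 48 63 50 66 84, matching the test's expected data.
  for (float x : out) {
    std::printf("%g ", x);
  }
  std::printf("\n");
  return 0;
}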

kernels/optimized/vec/functional_base.h

Lines changed: 44 additions & 24 deletions
@@ -326,10 +326,49 @@ inline void map4(
 }
 
 
-// Map vec_fun across input_data and input_data2, where input_data is
-// a two-dimensional array of size (size, size2), input_data2 is a
-// one-dimensional array of size size2, and input_data2 is broadcast
-// to be of size (size, size2).
+// This function implements a broadcasting binary operation on two tensors,
+// where the lhs tensor is treated as having shape
+// [outer_size, broadcast_size, inner_size] and the rhs tensor as having
+// shape [outer_size, 1, inner_size]; dim 1 is the broadcasting dimension.
+// This formulation can map broadcasting on any dim = broadcast_dim for any
+// two N-dimensional tensors, where 0 < broadcast_dim < N-1.
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_3d_and_unsqueezed_3d(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size,
+    int64_t inner_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = inner_size * broadcast_size;
+  int64_t outer_stride_rhs = inner_size;
+  int64_t broadcast_stride_lhs = inner_size;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
+    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
+      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
+      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
+      int64_t inner_idx = 0;
+      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx);
+      }
+      if (inner_size - inner_idx > 0) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
+      }
+    }
+  }
+}
+
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
@@ -338,27 +377,8 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  using Vec = vec::Vectorized<scalar_t>;
-  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
-    const scalar_t* input_data_row = input_data + outer_idx * size2;
-    scalar_t* output_data_row = output_data + outer_idx * size2;
-    int64_t inner_idx = 0;
-    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx);
-    }
-    if (size2 - inner_idx > 0) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
-    }
-  }
+  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }
 
-
-
 } // namespace vec
 } // namespace executorch
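The inner loop follows the standard main-loop / masked-tail pattern: whole Vec::size() chunks go through the vectorized body, and the remainder goes through the count-taking loadu/store overloads. A scalar stand-in to show just that structure (FakeVec is hypothetical and not SIMD-backed; the real executorch::vec::Vectorized provides the same count-taking loadu/store signatures, as used in the diff above):

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>

// Scalar stand-in for Vectorized<float> with 4 "lanes".
struct FakeVec {
  static constexpr int64_t size() { return 4; }
  std::array<float, 4> v{};
  // loadu with an explicit count mimics the partial load used for the tail;
  // missing lanes stay zero.
  static FakeVec loadu(const float* p, int64_t count = 4) {
    FakeVec out;
    std::copy(p, p + count, out.v.begin());
    return out;
  }
  void store(float* p, int64_t count = 4) const {
    std::copy(v.begin(), v.begin() + count, p);
  }
};

FakeVec mul(FakeVec a, FakeVec b) {
  FakeVec c;
  for (int k = 0; k < 4; ++k) {
    c.v[k] = a.v[k] * b.v[k];
  }
  return c;
}

int main() {
  // inner_size = 7 with 4 lanes: one full vector iteration, then a 3-lane tail.
  constexpr int64_t inner_size = 7;
  const float lhs[inner_size] = {1, 2, 3, 4, 5, 6, 7};
  const float rhs[inner_size] = {10, 10, 10, 10, 10, 10, 10};
  float out[inner_size];
  int64_t i = 0;
  for (; i < inner_size - (inner_size % FakeVec::size()); i += FakeVec::size()) {
    mul(FakeVec::loadu(lhs + i), FakeVec::loadu(rhs + i)).store(out + i);
  }
  if (inner_size - i > 0) { // masked tail, exactly as in the kernel above
    mul(FakeVec::loadu(lhs + i, inner_size - i),
        FakeVec::loadu(rhs + i, inner_size - i))
        .store(out + i, inner_size - i);
  }
  for (float x : out) {
    std::printf("%g ", x); // 10 20 30 40 50 60 70
  }
  std::printf("\n");
  return 0;
}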

kernels/test/op_mul_test.cpp

Lines changed: 89 additions & 0 deletions
@@ -153,6 +153,73 @@ class OpMulOutTest : public OperatorTest {
     }
   }
 
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
+
+    // Destination for output of mul.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected = tf_a.make(
+        {2, 2, 3}, /*data=*/{2, 6, 12, 8, 15, 24, 35, 48, 63, 50, 66, 84});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_4D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    Tensor b = tf_a.make(
+        {2, 1, 3, 5},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+    // Destination for output of mul.
+    Tensor out = tf_a.zeros({2, 2, 3, 5});
+    Tensor expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1, 4, 9, 16, 25, 36, 49, 64, 81, 100,
+                  121, 144, 169, 196, 225, 16, 34, 54, 76, 100,
+                  126, 154, 184, 216, 250, 286, 324, 364, 406, 450,
+                  496, 544, 594, 646, 700, 756, 814, 874, 936, 1000,
+                  1066, 1134, 1204, 1276, 1350, 736, 799, 864, 931, 1000,
+                  1071, 1144, 1219, 1296, 1375, 1456, 1539, 1624, 1711, 1800});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+
+    b = tf_a.make(
+        {2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+                                11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1, 4, 9, 16, 25, 6, 14, 24, 36, 50,
+                  11, 24, 39, 56, 75, 96, 119, 144, 171, 200,
+                  126, 154, 184, 216, 250, 156, 189, 224, 261, 300,
+                  341, 384, 429, 476, 525, 396, 444, 494, 546, 600,
+                  451, 504, 559, 616, 675, 736, 799, 864, 931, 1000,
+                  816, 884, 954, 1026, 1100, 896, 969, 1044, 1121, 1200});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+  }
+
   template <ScalarType DTYPE>
   void test_broadcast_b2a() {
     TensorFactory<DTYPE> tf_a;
@@ -296,6 +363,16 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) {
   test_broadcast_a2b<ScalarType::Int>();
   test_broadcast_a2b<ScalarType::Half>();
   test_broadcast_a2b<ScalarType::BFloat16>();
+
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  test_broadcast_3D<ScalarType::Half>();
+  test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors
+  test_broadcast_4D<ScalarType::Float>();
+  test_broadcast_4D<ScalarType::Half>();
+  test_broadcast_4D<ScalarType::BFloat16>();
 }
 
 // Broadcast tensor a's size to tensor b's size
@@ -305,6 +382,18 @@ TEST_F(OpMulOutTest, BroadcastB2ATest) {
   test_broadcast_b2a<ScalarType::BFloat16>();
 }
 
+TEST_F(OpMulOutTest, BroadcastNDTest) {
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  test_broadcast_3D<ScalarType::Half>();
+  test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors
+  test_broadcast_4D<ScalarType::Float>();
+  test_broadcast_4D<ScalarType::Half>();
+  test_broadcast_4D<ScalarType::BFloat16>();
+}
+
 // Broadcast tensor a and b's size to a new size c.
 TEST_F(OpMulOutTest, BroadcastAB2CTest) {
   TensorFactory<ScalarType::Int> tf_a;
