Skip to content

Commit 1ae997c

Browse files
[executorch] Ignore leading 1 dimensions when checking optimized path for op_mul (#4963)
A 1 x 1 x ... x m x n tensor can be element-wise multiplied with a m x n tensor just fine. Pull Request resolved: #4806 Co-authored-by: Scott Wolchok <[email protected]>
1 parent 2553b85 commit 1ae997c

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

kernels/optimized/cpu/op_mul.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,25 @@ using ScalarType = exec_aten::ScalarType;
2222

2323
namespace {
2424

25+
bool sizes_match_ignoring_leading_1s(
26+
ArrayRef<Tensor::SizesType> lhs,
27+
ArrayRef<Tensor::SizesType> rhs) {
28+
auto lhs_begin = lhs.begin();
29+
auto lhs_end = lhs.end();
30+
while (lhs_begin != lhs_end && *lhs_begin == 1) {
31+
++lhs_begin;
32+
}
33+
34+
auto rhs_begin = rhs.begin();
35+
auto rhs_end = rhs.end();
36+
while (rhs_begin != rhs_end && *rhs_begin == 1) {
37+
++rhs_begin;
38+
}
39+
40+
return ((lhs_end - lhs_begin) == (rhs_end - rhs_begin)) &&
41+
std::equal(lhs_begin, lhs_end, rhs_begin);
42+
}
43+
2544
// Move to generic util as this is applicable to all binary ops
2645
bool can_use_optimized_path(
2746
const Tensor& a,
@@ -38,7 +57,9 @@ bool can_use_optimized_path(
3857
(a_type != ScalarType::Half && b_type != ScalarType::Half);
3958
can_use_optimized_path = can_use_optimized_path &&
4059
(a.sizes().equals(b.sizes()) ||
41-
(a.numel() == b.numel() && a.numel() == out.numel()));
60+
(a.numel() == b.numel() &&
61+
(a.numel() == out.numel() ||
62+
sizes_match_ignoring_leading_1s(a.sizes(), b.sizes()))));
4263
return can_use_optimized_path;
4364
}
4465

kernels/test/op_mul_test.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,21 @@ TEST_F(OpMulOutTest, BoolTensors) {
165165
EXPECT_TENSOR_EQ(out, tf.make(sizes, /*data=*/{false, false, true, false}));
166166
}
167167

168+
TEST_F(OpMulOutTest, OptimizedPathIgnoresLeading1Dimensions) {
169+
TensorFactory<ScalarType::Float> tf;
170+
171+
const std::vector<int32_t> sizes1 = {1, 1, 2, 2};
172+
const std::vector<int32_t> sizes2 = {1, 2, 2};
173+
174+
// Destination for the mul.
175+
Tensor out = tf.zeros(sizes1);
176+
177+
// Multiply two tensors
178+
op_mul_out(
179+
tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes2), out);
180+
EXPECT_TENSOR_CLOSE(out, tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}));
181+
}
182+
168183
// Mismatched shape tests.
169184
TEST_F(OpMulOutTest, MismatchedInputShapesDies) {
170185
if (SupportedFeatures::get()->is_aten) {

0 commit comments

Comments
 (0)