
Commit 0536862

Fix bug in optimized mul's broadcast handling (#11590)
Summary: When two tensors contain the same number of elements but differ in rank, the optimized path resized the output incorrectly. For example, a[6] * b[1, 1, 6] should produce out[1, 1, 6], but the existing code resized the output using a.sizes(), i.e. to [6]. The fix resizes the output to the broadcast target size of a and b instead.

Test Plan: Added tests covering the mismatched-rank broadcast cases.
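For context, the broadcast target shape is found by aligning the two shapes from their trailing dimensions, with missing or size-1 dimensions stretched to match. Below is a minimal, framework-free C++ sketch of that rule; the broadcast_shape helper and the use of std::vector are illustrative stand-ins, not ExecuTorch APIs.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: compute the broadcast target shape of two dim vectors.
// This mirrors the rule the fixed kernel relies on via
// resize_to_broadcast_target_size(a, b, out).
std::vector<int64_t> broadcast_shape(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  const size_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> out(ndim, 1);
  for (size_t i = 0; i < ndim; ++i) {
    // Align from the trailing dimension; a missing dimension counts as 1.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    // Assumes the shapes are broadcast-compatible (equal, or one of them is 1).
    out[ndim - 1 - i] = std::max(da, db);
  }
  return out;
}

int main() {
  // The failing case from the summary: a[6] * b[1, 1, 6].
  // The broadcast target is [1, 1, 6]; resizing from a.sizes() alone gave [6].
  for (int64_t d : broadcast_shape({6}, {1, 1, 6})) {
    std::cout << d << ' ';  // prints: 1 1 6
  }
  std::cout << '\n';
  return 0;
}

For the shapes in this commit, {6} and {1, 1, 6} broadcast to {1, 1, 6}, which is why resizing the output to a.sizes() was incorrect.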
1 parent 0d3e750 commit 0536862

File tree

2 files changed: +112 -6 lines changed

kernels/optimized/cpu/op_mul.cpp

Lines changed: 3 additions & 6 deletions
@@ -111,14 +111,11 @@ Tensor& opt_mul_out(
 
   auto selected_optimized_path = select_optimized_path(a, b, out);
   if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_tensor(out, a.sizes());
-    ET_KERNEL_CHECK_MSG(
+    ET_KERNEL_CHECK(
         ctx,
-        error == Error::Ok,
+        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
         InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
+        out);
 
     if (executorch::runtime::isComplexType(out_type)) {
       ET_KERNEL_CHECK(

kernels/test/op_mul_test.cpp

Lines changed: 109 additions & 0 deletions
@@ -794,3 +794,112 @@ TEST_F(OpMulScalarOutTest, BFloat16SanityCheck) {
   // Check that it matches the expected output.
   EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4}));
 }
+
+// Tests for broadcast handling fix: when tensor dimensions don't match,
+// the output should be resized to match the tensor with higher dimensionality
+TEST_F(OpMulOutTest, BroadcastDimensionMismatchFix) {
+  TensorFactory<ScalarType::Float> tf;
+
+  // Test case: tensor a of size [6] and b of size [1, 1, 6]
+  // Expected output should be [1, 1, 6], not [6]
+  Tensor a = tf.make({6}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
+  Tensor b = tf.make({1, 1, 6}, {2.0, 2.0, 2.0, 2.0, 2.0, 2.0});
+
+  // Create output tensor with expected broadcast shape [1, 1, 6]
+  Tensor out = tf.zeros({1, 1, 6});
+
+  // Call the mul function
+  Tensor& result = op_mul_out(a, b, out);
+
+  // Verify the output shape is [1, 1, 6]
+  EXPECT_EQ(result.dim(), 3);
+  EXPECT_EQ(result.size(0), 1);
+  EXPECT_EQ(result.size(1), 1);
+  EXPECT_EQ(result.size(2), 6);
+
+  // Verify the values are correct (element-wise multiplication with
+  // broadcasting)
+  Tensor expected = tf.make({1, 1, 6}, {2.0, 4.0, 6.0, 8.0, 10.0, 12.0});
+  EXPECT_TENSOR_CLOSE(result, expected);
+}
+
+TEST_F(OpMulOutTest, BroadcastDimensionMismatchReversed) {
+  TensorFactory<ScalarType::Float> tf;
+
+  // Test case: tensor a of size [1, 1, 6] and b of size [6]
+  // Expected output should be [1, 1, 6]
+  Tensor a = tf.make({1, 1, 6}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
+  Tensor b = tf.make({6}, {2.0, 2.0, 2.0, 2.0, 2.0, 2.0});
+
+  // Create output tensor with expected broadcast shape [1, 1, 6]
+  Tensor out = tf.zeros({1, 1, 6});
+
+  // Call the mul function
+  Tensor& result = op_mul_out(a, b, out);
+
+  // Verify the output shape is [1, 1, 6]
+  EXPECT_EQ(result.dim(), 3);
+  EXPECT_EQ(result.size(0), 1);
+  EXPECT_EQ(result.size(1), 1);
+  EXPECT_EQ(result.size(2), 6);
+
+  // Verify the values are correct (element-wise multiplication with
+  // broadcasting)
+  Tensor expected = tf.make({1, 1, 6}, {2.0, 4.0, 6.0, 8.0, 10.0, 12.0});
+  EXPECT_TENSOR_CLOSE(result, expected);
+}
+
+TEST_F(OpMulOutTest, BroadcastDimensionMismatchWithDifferentTypes) {
+  // Test the same broadcast fix with different data types
+  TensorFactory<ScalarType::Half> tf_half;
+  TensorFactory<ScalarType::BFloat16> tf_bf16;
+  TensorFactory<ScalarType::Int> tf_int;
+
+  // Test with Half precision
+  {
+    Tensor a = tf_half.make({4}, {1.0, 2.0, 3.0, 4.0});
+    Tensor b = tf_half.make({1, 1, 4}, {2.0, 2.0, 2.0, 2.0});
+    Tensor out = tf_half.zeros({1, 1, 4});
+
+    Tensor& result = op_mul_out(a, b, out);
+    EXPECT_EQ(result.dim(), 3);
+    EXPECT_EQ(result.size(0), 1);
+    EXPECT_EQ(result.size(1), 1);
+    EXPECT_EQ(result.size(2), 4);
+
+    Tensor expected = tf_half.make({1, 1, 4}, {2.0, 4.0, 6.0, 8.0});
+    EXPECT_TENSOR_CLOSE(result, expected);
+  }
+
+  // Test with BFloat16
+  {
+    Tensor a = tf_bf16.make({4}, {1.0, 2.0, 3.0, 4.0});
+    Tensor b = tf_bf16.make({1, 1, 4}, {2.0, 2.0, 2.0, 2.0});
+    Tensor out = tf_bf16.zeros({1, 1, 4});
+
+    Tensor& result = op_mul_out(a, b, out);
+    EXPECT_EQ(result.dim(), 3);
+    EXPECT_EQ(result.size(0), 1);
+    EXPECT_EQ(result.size(1), 1);
+    EXPECT_EQ(result.size(2), 4);
+
+    Tensor expected = tf_bf16.make({1, 1, 4}, {2.0, 4.0, 6.0, 8.0});
+    EXPECT_TENSOR_CLOSE(result, expected);
+  }
+
+  // Test with Int
+  {
+    Tensor a = tf_int.make({4}, {1, 2, 3, 4});
+    Tensor b = tf_int.make({1, 1, 4}, {2, 2, 2, 2});
+    Tensor out = tf_int.zeros({1, 1, 4});
+
+    Tensor& result = op_mul_out(a, b, out);
+    EXPECT_EQ(result.dim(), 3);
+    EXPECT_EQ(result.size(0), 1);
+    EXPECT_EQ(result.size(1), 1);
+    EXPECT_EQ(result.size(2), 4);
+
+    Tensor expected = tf_int.make({1, 1, 4}, {2, 4, 6, 8});
+    EXPECT_TENSOR_EQ(result, expected);
+  }
+}
