@@ -794,3 +794,112 @@ TEST_F(OpMulScalarOutTest, BFloat16SanityCheck) {
794
794
// Check that it matches the expected output.
795
795
EXPECT_TENSOR_CLOSE (out, tf.make (sizes, {2.6 , 4.2 , 9.2 , 16.4 }));
796
796
}
797
+
798
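+ // Tests for the broadcast handling fix: when the input tensors have different
799
+ // numbers of dimensions, the output should be resized to the broadcast shape,
+ // which has the rank of the higher-dimensional input (e.g. [6] * [1, 1, 6]
+ // yields [1, 1, 6], not [6]).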
800
+ TEST_F (OpMulOutTest, BroadcastDimensionMismatchFix) {
801
+ TensorFactory<ScalarType::Float> tf;
802
+
803
+ // Test case: tensor a of size [6] and b of size [1, 1, 6]
804
+ // Expected output should be [1, 1, 6], not [6]
805
+ Tensor a = tf.make ({6 }, {1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 });
806
+ Tensor b = tf.make ({1 , 1 , 6 }, {2.0 , 2.0 , 2.0 , 2.0 , 2.0 , 2.0 });
807
+
808
+ // Create output tensor with expected broadcast shape [1, 1, 6]
809
+ Tensor out = tf.zeros ({1 , 1 , 6 });
810
+
811
+ // Call the mul function
812
+ Tensor& result = op_mul_out (a, b, out);
813
+
814
+ // Verify the output shape is [1, 1, 6]
815
+ EXPECT_EQ (result.dim (), 3 );
816
+ EXPECT_EQ (result.size (0 ), 1 );
817
+ EXPECT_EQ (result.size (1 ), 1 );
818
+ EXPECT_EQ (result.size (2 ), 6 );
819
+
820
+ // Verify the values are correct (element-wise multiplication with
821
+ // broadcasting)
822
+ Tensor expected = tf.make ({1 , 1 , 6 }, {2.0 , 4.0 , 6.0 , 8.0 , 10.0 , 12.0 });
823
+ EXPECT_TENSOR_CLOSE (result, expected);
824
+ }
825
+
826
+ TEST_F (OpMulOutTest, BroadcastDimensionMismatchReversed) {
827
+ TensorFactory<ScalarType::Float> tf;
828
+
829
+ // Test case: tensor a of size [1, 1, 6] and b of size [6]
830
+ // Expected output should be [1, 1, 6]
831
+ Tensor a = tf.make ({1 , 1 , 6 }, {1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 });
832
+ Tensor b = tf.make ({6 }, {2.0 , 2.0 , 2.0 , 2.0 , 2.0 , 2.0 });
833
+
834
+ // Create output tensor with expected broadcast shape [1, 1, 6]
835
+ Tensor out = tf.zeros ({1 , 1 , 6 });
836
+
837
+ // Call the mul function
838
+ Tensor& result = op_mul_out (a, b, out);
839
+
840
+ // Verify the output shape is [1, 1, 6]
841
+ EXPECT_EQ (result.dim (), 3 );
842
+ EXPECT_EQ (result.size (0 ), 1 );
843
+ EXPECT_EQ (result.size (1 ), 1 );
844
+ EXPECT_EQ (result.size (2 ), 6 );
845
+
846
+ // Verify the values are correct (element-wise multiplication with
847
+ // broadcasting)
848
+ Tensor expected = tf.make ({1 , 1 , 6 }, {2.0 , 4.0 , 6.0 , 8.0 , 10.0 , 12.0 });
849
+ EXPECT_TENSOR_CLOSE (result, expected);
850
+ }
851
+
852
+ TEST_F (OpMulOutTest, BroadcastDimensionMismatchWithDifferentTypes) {
853
+ // Test the same broadcast fix with different data types
854
+ TensorFactory<ScalarType::Half> tf_half;
855
+ TensorFactory<ScalarType::BFloat16> tf_bf16;
856
+ TensorFactory<ScalarType::Int> tf_int;
857
+
858
+ // Test with Half precision
859
+ {
860
+ Tensor a = tf_half.make ({4 }, {1.0 , 2.0 , 3.0 , 4.0 });
861
+ Tensor b = tf_half.make ({1 , 1 , 4 }, {2.0 , 2.0 , 2.0 , 2.0 });
862
+ Tensor out = tf_half.zeros ({1 , 1 , 4 });
863
+
864
+ Tensor& result = op_mul_out (a, b, out);
865
+ EXPECT_EQ (result.dim (), 3 );
866
+ EXPECT_EQ (result.size (0 ), 1 );
867
+ EXPECT_EQ (result.size (1 ), 1 );
868
+ EXPECT_EQ (result.size (2 ), 4 );
869
+
870
+ Tensor expected = tf_half.make ({1 , 1 , 4 }, {2.0 , 4.0 , 6.0 , 8.0 });
871
+ EXPECT_TENSOR_CLOSE (result, expected);
872
+ }
873
+
874
+ // Test with BFloat16
875
+ {
876
+ Tensor a = tf_bf16.make ({4 }, {1.0 , 2.0 , 3.0 , 4.0 });
877
+ Tensor b = tf_bf16.make ({1 , 1 , 4 }, {2.0 , 2.0 , 2.0 , 2.0 });
878
+ Tensor out = tf_bf16.zeros ({1 , 1 , 4 });
879
+
880
+ Tensor& result = op_mul_out (a, b, out);
881
+ EXPECT_EQ (result.dim (), 3 );
882
+ EXPECT_EQ (result.size (0 ), 1 );
883
+ EXPECT_EQ (result.size (1 ), 1 );
884
+ EXPECT_EQ (result.size (2 ), 4 );
885
+
886
+ Tensor expected = tf_bf16.make ({1 , 1 , 4 }, {2.0 , 4.0 , 6.0 , 8.0 });
887
+ EXPECT_TENSOR_CLOSE (result, expected);
888
+ }
889
+
890
+ // Test with Int
891
+ {
892
+ Tensor a = tf_int.make ({4 }, {1 , 2 , 3 , 4 });
893
+ Tensor b = tf_int.make ({1 , 1 , 4 }, {2 , 2 , 2 , 2 });
894
+ Tensor out = tf_int.zeros ({1 , 1 , 4 });
895
+
896
+ Tensor& result = op_mul_out (a, b, out);
897
+ EXPECT_EQ (result.dim (), 3 );
898
+ EXPECT_EQ (result.size (0 ), 1 );
899
+ EXPECT_EQ (result.size (1 ), 1 );
900
+ EXPECT_EQ (result.size (2 ), 4 );
901
+
902
+ Tensor expected = tf_int.make ({1 , 1 , 4 }, {2 , 4 , 6 , 8 });
903
+ EXPECT_TENSOR_EQ (result, expected);
904
+ }
905
+ }
0 commit comments