@@ -112,6 +112,122 @@ class OpAddOutKernelTest : public OperatorTest {
    // tests.
    EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125}));
  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
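+    // b's middle dimension has size 1, so b[i][0][:] is added to both rows of
+    // a[i]: e.g. {1, 2, 3} + {2, 3, 4} = {3, 5, 7} and
+    // {4, 5, 6} + {2, 3, 4} = {6, 8, 10}.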
+
+    // Destination for the output of add.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected = tf_a.make(
+        {2, 2, 3}, /*data=*/{3, 5, 7, 6, 8, 10, 12, 14, 16, 15, 17, 19});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_4D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    Tensor b = tf_a.make(
+        {2, 1, 3, 5},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
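+    // b's second dimension has size 1, so each {3, 5} slab b[i][0] is added to
+    // both a[i][0] and a[i][1].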
+
+    // Destination for the output of add.
+    Tensor out = tf_a.zeros({2, 2, 3, 5});
+    Tensor expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
+                  17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45,
+                  47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75,
+                  62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+
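+    // Repeat with b of shape {2, 2, 1, 5}: the broadcast (size-1) dimension is
+    // now the third one, so each row b[i][j][0][:] is added to all three rows
+    // of a[i][j].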
+    b = tf_a.make(
+        {2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+                                11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{2, 4, 6, 8, 10, 7, 9, 11, 13, 15, 12, 14, 16, 18, 20,
+                  22, 24, 26, 28, 30, 27, 29, 31, 33, 35, 32, 34, 36, 38, 40,
+                  42, 44, 46, 48, 50, 47, 49, 51, 53, 55, 52, 54, 56, 58, 60,
+                  62, 64, 66, 68, 70, 67, 69, 71, 73, 75, 72, 74, 76, 78, 80});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_last_dim() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5});
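+    // b has a single column, so b[i][0] is added to every element of row i of
+    // a: e.g. {1, 2, 3} + 2 = {3, 4, 5} and {4, 5, 6} + 3 = {7, 8, 9}.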
+
+    // Destination for the output of add.
+    Tensor out = tf_a.zeros({4, 3});
+    Tensor expected =
+        tf_a.make({4, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+
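+    // Same check at rank 3: b's last dimension has size 1, so each scalar
+    // b[i][j][0] is added across one row of a.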
+    a = tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5});
+
+    // Destination for the output of add.
+    out = tf_a.zeros({2, 2, 3});
+    expected = tf_a.make(
+        {2, 2, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+
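+    // And at rank 4: b has shape {2, 2, 3, 1}, so b[i][j][k][0] is added to
+    // every element of the corresponding length-5 row of a.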
+    a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    b = tf_a.make(
+        {2, 2, 3, 1},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+
+    // Destination for the output of add.
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18,
+                  20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36,
+                  38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54,
+                  56, 57, 58, 59, 60, 62, 63, 64, 65, 66, 68, 69, 70, 71, 72});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+  }
};

class OpAddScalarOutKernelTest : public OperatorTest {
@@ -371,6 +487,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) {
  EXPECT_TENSOR_EQ(out, ret);
}

+TEST_F(OpAddOutKernelTest, BroadcastNDTest) {
+  // Test 3D tensors.
+  test_broadcast_3D<ScalarType::Float>();
+  test_broadcast_3D<ScalarType::Half>();
+  test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors.
+  test_broadcast_4D<ScalarType::Float>();
+  test_broadcast_4D<ScalarType::Half>();
+  test_broadcast_4D<ScalarType::BFloat16>();
+
+  // Test broadcasting on the last dimension.
+  test_broadcast_last_dim<ScalarType::Float>();
+  test_broadcast_last_dim<ScalarType::Half>();
+  test_broadcast_last_dim<ScalarType::BFloat16>();
+}
+

//
// Death Tests
//