@@ -112,6 +112,125 @@ class OpAddOutKernelTest : public OperatorTest {
112
112
// tests.
113
113
EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {2.5 , 3.5 , 5.75 , 10.125 }));
114
114
}
115
+
116
+ template <ScalarType DTYPE>
117
+ void test_broadcast_3D () {
118
+ TensorFactory<DTYPE> tf_a;
119
+
120
+ Tensor a =
121
+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
122
+ Tensor b = tf_a.make ({2 , 1 , 3 }, /* data=*/ {2 , 3 , 4 , 5 , 6 , 7 });
123
+
124
+ // Destination for output of mul.
125
+ Tensor out =
126
+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
127
+ Tensor expected = tf_a.make (
128
+ {2 , 2 , 3 }, /* data=*/ {3 , 5 , 7 , 6 , 8 , 10 , 12 , 14 , 16 , 15 , 17 , 19 });
129
+
130
+ // Check that it matches the expected output.
131
+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
132
+ expected = tf_a.make (
133
+ {2 , 2 , 3 },
134
+ /* data=*/ {3.5 , 6 , 8.5 , 8 , 10.5 , 13 , 15.5 , 18 , 20.5 , 20 , 22.5 , 25 });
135
+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.5 , out), expected);
136
+ }
137
+
138
+ template <ScalarType DTYPE>
139
+ void test_broadcast_4D () {
140
+ TensorFactory<DTYPE> tf_a;
141
+
142
+ Tensor a = tf_a.make (
143
+ {2 , 2 , 3 , 5 },
144
+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
145
+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
146
+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
147
+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
148
+ Tensor b = tf_a.make (
149
+ {2 , 1 , 3 , 5 },
150
+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
151
+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 });
152
+
153
+ // Destination for output of mul.
154
+ Tensor out = tf_a.zeros ({2 , 2 , 3 , 5 });
155
+ Tensor expected = tf_a.make (
156
+ {2 , 2 , 3 , 5 },
157
+ /* data=*/ {2 , 4 , 6 , 8 , 10 , 12 , 14 , 16 , 18 , 20 , 22 , 24 , 26 , 28 , 30 ,
158
+ 17 , 19 , 21 , 23 , 25 , 27 , 29 , 31 , 33 , 35 , 37 , 39 , 41 , 43 , 45 ,
159
+ 47 , 49 , 51 , 53 , 55 , 57 , 59 , 61 , 63 , 65 , 67 , 69 , 71 , 73 , 75 ,
160
+ 62 , 64 , 66 , 68 , 70 , 72 , 74 , 76 , 78 , 80 , 82 , 84 , 86 , 88 , 90 });
161
+
162
+ // Check that it matches the expected output.
163
+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
164
+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
165
+
166
+ b = tf_a.make (
167
+ {2 , 2 , 1 , 5 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ,
168
+ 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 });
169
+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
170
+ expected = tf_a.make (
171
+ {2 , 2 , 3 , 5 },
172
+ /* data=*/ {2 , 4 , 6 , 8 , 10 , 7 , 9 , 11 , 13 , 15 , 12 , 14 , 16 , 18 , 20 ,
173
+ 22 , 24 , 26 , 28 , 30 , 27 , 29 , 31 , 33 , 35 , 32 , 34 , 36 , 38 , 40 ,
174
+ 42 , 44 , 46 , 48 , 50 , 47 , 49 , 51 , 53 , 55 , 52 , 54 , 56 , 58 , 60 ,
175
+ 62 , 64 , 66 , 68 , 70 , 67 , 69 , 71 , 73 , 75 , 72 , 74 , 76 , 78 , 80 });
176
+
177
+ // Check that it matches the expected output.
178
+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
179
+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
180
+ }
181
+
182
+ template <ScalarType DTYPE>
183
+ void test_broadcast_last_dim () {
184
+ TensorFactory<DTYPE> tf_a;
185
+
186
+ Tensor a =
187
+ tf_a.make ({4 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
188
+ Tensor b = tf_a.make ({4 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
189
+
190
+ // Destination for output of mul.
191
+ Tensor out = tf_a.zeros ({4 , 3 });
192
+ Tensor expected =
193
+ tf_a.make ({4 , 3 }, /* data=*/ {3 , 4 , 5 , 7 , 8 , 9 , 11 , 12 , 13 , 15 , 16 , 17 });
194
+
195
+ // Check that it matches the expected output.
196
+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
197
+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
198
+
199
+ a = tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
200
+ b = tf_a.make ({2 , 2 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
201
+
202
+ // Destination for output of mul.
203
+ out = tf_a.zeros ({2 , 2 , 3 });
204
+ expected = tf_a.make (
205
+ {2 , 2 , 3 }, /* data=*/ {3 , 4 , 5 , 7 , 8 , 9 , 11 , 12 , 13 , 15 , 16 , 17 });
206
+
207
+ // Check that it matches the expected output.
208
+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
209
+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
210
+
211
+ a = tf_a.make (
212
+ {2 , 2 , 3 , 5 },
213
+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
214
+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
215
+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
216
+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
217
+ b = tf_a.make (
218
+ {2 , 2 , 3 , 1 },
219
+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
220
+
221
+ // Destination for output of mul.
222
+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
223
+ expected = tf_a.make (
224
+ {2 , 2 , 3 , 5 },
225
+ /* data=*/ {2 , 3 , 4 , 5 , 6 , 8 , 9 , 10 , 11 , 12 , 14 , 15 , 16 , 17 , 18 ,
226
+ 20 , 21 , 22 , 23 , 24 , 26 , 27 , 28 , 29 , 30 , 32 , 33 , 34 , 35 , 36 ,
227
+ 38 , 39 , 40 , 41 , 42 , 44 , 45 , 46 , 47 , 48 , 50 , 51 , 52 , 53 , 54 ,
228
+ 56 , 57 , 58 , 59 , 60 , 62 , 63 , 64 , 65 , 66 , 68 , 69 , 70 , 71 , 72 });
229
+
230
+ // Check that it matches the expected output.
231
+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
232
+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
233
+ }
115
234
};
116
235
117
236
class OpAddScalarOutKernelTest : public OperatorTest {
@@ -371,6 +490,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) {
371
490
EXPECT_TENSOR_EQ (out, ret);
372
491
}
373
492
493
+ TEST_F (OpAddOutKernelTest, BroadcastNDTest) {
494
+ // Test 3D tensors
495
+ test_broadcast_3D<ScalarType::Float>();
496
+ test_broadcast_3D<ScalarType::Half>();
497
+ test_broadcast_3D<ScalarType::BFloat16>();
498
+
499
+ // Test 4D tensors
500
+ test_broadcast_4D<ScalarType::Float>();
501
+ test_broadcast_4D<ScalarType::Half>();
502
+ test_broadcast_4D<ScalarType::BFloat16>();
503
+
504
+ // Test broadcasting on the last dimension
505
+ test_broadcast_last_dim<ScalarType::Float>();
506
+ test_broadcast_last_dim<ScalarType::Half>();
507
+ test_broadcast_last_dim<ScalarType::BFloat16>();
508
+ }
509
+
374
510
//
375
511
// Death Tests
376
512
//
0 commit comments