@@ -99,6 +99,109 @@ class OpSubOutTest : public OperatorTest {
99
99
EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {0.1 , 1.2 , 3.4 , 7.8 }));
100
100
}
101
101
102
+ template <ScalarType DTYPE>
103
+ void test_broadcast_3D () {
104
+ TensorFactory<DTYPE> tf_a;
105
+
106
+ Tensor a =
107
+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
108
+ Tensor b = tf_a.make ({2 , 1 , 3 }, /* data=*/ {2 , 3 , 4 , 5 , 6 , 7 });
109
+
110
+ // Destination for output of mul.
111
+ Tensor out =
112
+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
113
+ Tensor expected =
114
+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {-1 , -1 , -1 , 2 , 2 , 2 , 2 , 2 , 2 , 5 , 5 , 5 });
115
+
116
+ // Check that it matches the expected output.
117
+ EXPECT_TENSOR_CLOSE (op_sub_out (a, b, 1.0 , out), expected);
118
+ // b - a * 1.5 output should be
119
+ expected = tf_a.make (
120
+ {2 , 2 , 3 },
121
+ /* data=*/
122
+ {0.5 ,
123
+ 0.0 ,
124
+ -0.5 ,
125
+ -4.0 ,
126
+ -4.5 ,
127
+ -5.0 ,
128
+ -5.5 ,
129
+ -6.0 ,
130
+ -6.5 ,
131
+ -10.0 ,
132
+ -10.5 ,
133
+ -11.0 });
134
+ EXPECT_TENSOR_CLOSE (op_sub_out (b, a, 1.5 , out), expected);
135
+ }
136
+
137
+ template <ScalarType DTYPE>
138
+ void test_broadcast_4D () {
139
+ TensorFactory<DTYPE> tf_a;
140
+
141
+ Tensor a = tf_a.make (
142
+ {2 , 2 , 3 , 5 },
143
+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
144
+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
145
+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
146
+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
147
+ Tensor b = tf_a.make (
148
+ {2 , 1 , 3 , 5 },
149
+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
150
+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 });
151
+
152
+ // Destination for output of mul.
153
+ Tensor out = tf_a.zeros ({2 , 2 , 3 , 5 });
154
+ Tensor expected = tf_a.make (
155
+ {2 , 2 , 3 , 5 },
156
+ /* data=*/ {0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
157
+ 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 ,
158
+ 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 ,
159
+ 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 });
160
+
161
+ // Check that it matches the expected output.
162
+ EXPECT_TENSOR_CLOSE (op_sub_out (a, b, 1.0 , out), expected);
163
+ expected = tf_a.make (
164
+ {2 , 2 , 3 , 5 },
165
+ /* data=*/ {0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
166
+ 0 , 0 , 0 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 ,
167
+ -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 ,
168
+ -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -30 , -30 , -30 ,
169
+ -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 });
170
+ EXPECT_TENSOR_CLOSE (op_sub_out (b, a, 1.0 , out), expected);
171
+
172
+ b = tf_a.make (
173
+ {2 , 2 , 1 , 5 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ,
174
+ 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 });
175
+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
176
+ expected = tf_a.make (
177
+ {2 , 2 , 3 , 5 },
178
+ /* data=*/ {0 , 0 , 0 , 0 , 0 , 5 , 5 , 5 , 5 , 5 , 10 , 10 , 10 , 10 , 10 ,
179
+ 10 , 10 , 10 , 10 , 10 , 15 , 15 , 15 , 15 , 15 , 20 , 20 , 20 , 20 , 20 ,
180
+ 20 , 20 , 20 , 20 , 20 , 25 , 25 , 25 , 25 , 25 , 30 , 30 , 30 , 30 , 30 ,
181
+ 30 , 30 , 30 , 30 , 30 , 35 , 35 , 35 , 35 , 35 , 40 , 40 , 40 , 40 , 40 });
182
+
183
+ // Check that it matches the expected output.
184
+ EXPECT_TENSOR_CLOSE (op_sub_out (a, b, 1.0 , out), expected);
185
+ expected = tf_a.make (
186
+ {2 , 2 , 3 , 5 },
187
+ /* data=*/ {-0.5000 , -1.0000 , -1.5000 , -2.0000 , -2.5000 ,
188
+ -8.0000 , -8.5000 , -9.0000 , -9.5000 , -10.0000 ,
189
+ -15.5000 , -16.0000 , -16.5000 , -17.0000 , -17.5000 ,
190
+
191
+ -18.0000 , -18.5000 , -19.0000 , -19.5000 , -20.0000 ,
192
+ -25.5000 , -26.0000 , -26.5000 , -27.0000 , -27.5000 ,
193
+ -33.0000 , -33.5000 , -34.0000 , -34.5000 , -35.0000 ,
194
+
195
+ -35.5000 , -36.0000 , -36.5000 , -37.0000 , -37.5000 ,
196
+ -43.0000 , -43.5000 , -44.0000 , -44.5000 , -45.0000 ,
197
+ -50.5000 , -51.0000 , -51.5000 , -52.0000 , -52.5000 ,
198
+
199
+ -53.0000 , -53.5000 , -54.0000 , -54.5000 , -55.0000 ,
200
+ -60.5000 , -61.0000 , -61.5000 , -62.0000 , -62.5000 ,
201
+ -68.0000 , -68.5000 , -69.0000 , -69.5000 , -70.0000 });
202
+ EXPECT_TENSOR_CLOSE (op_sub_out (b, a, 1.5 , out), expected);
203
+ }
204
+
102
205
void test_sub_enumerate_a_types () {
103
206
#define ENUMERATE_TEST_ENTRY (ctype, dtype ) \
104
207
test_sub_enumerate_b_types<ScalarType::dtype>();
@@ -237,6 +340,19 @@ TEST_F(OpSubOutTest, BroadcastScalarRank0Supported) {
237
340
EXPECT_TENSOR_EQ (out, ret);
238
341
}
239
342
343
+ TEST_F (OpSubOutTest, BroadcastNDTest) {
344
+ // Test 3D tensors
345
+ test_broadcast_3D<ScalarType::Float>();
346
+ test_broadcast_3D<ScalarType::Half>();
347
+ // Sub doesnt yet support BFloat16
348
+ // test_broadcast_3D<ScalarType::BFloat16>();
349
+
350
+ // Test 4D tensors
351
+ test_broadcast_4D<ScalarType::Float>();
352
+ test_broadcast_4D<ScalarType::Half>();
353
+ // test_broadcast_4D<ScalarType::BFloat16>();
354
+ }
355
+
240
356
//
241
357
// Death Tests
242
358
//
0 commit comments