@@ -72,7 +72,7 @@ class OpMulOutTest : public OperatorTest {
72
72
#define ENUMERATE_TEST_ENTRY (ctype, dtype ) \
73
73
test_mul_enumerate_out_types<DTYPE_A, ScalarType::dtype>();
74
74
75
- ET_FORALL_REAL_TYPES_AND (Half, ENUMERATE_TEST_ENTRY)
75
+ ET_FORALL_REALHBF16_TYPES ( ENUMERATE_TEST_ENTRY)
76
76
77
77
#undef ENUMERATE_TEST_ENTRY
78
78
}
@@ -89,29 +89,99 @@ class OpMulOutTest : public OperatorTest {
89
89
90
90
// Multiply two tensors
91
91
op_mul_out (
92
- tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }), tf.ones (sizes), out);
93
- EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }));
92
+ tf.make (sizes, /* data=*/ {1.25 , 2.5 , 4.75 , 8.875 }), tf.ones (sizes), out);
93
+ EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {1.25 , 2.5 , 4.75 , 8.875 }));
94
94
95
95
op_mul_out (
96
96
tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }), tf.zeros (sizes), out);
97
97
EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {0.0 , 0.0 , 0.0 , 0.0 }));
98
98
99
99
op_mul_out (
100
- tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }),
101
- tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }),
100
+ tf.make (sizes, /* data=*/ {1.25 , 2.5 , 4.75 , 8.875 }),
101
+ tf.make (sizes, /* data=*/ {1.25 , 2.5 , 4.75 , 8.875 }),
102
102
out);
103
103
EXPECT_TENSOR_CLOSE (
104
- out, tf.make (sizes, /* data=*/ {1.21 , 4.84 , 19.36 , 77.44 }));
104
+ out, tf.make (sizes, /* data=*/ {1.5625 , 6.25 , 22.5625 , 78.765625 }));
105
105
}
106
106
107
107
void test_mul_enumerate_a_types () {
108
108
#define ENUMERATE_TEST_ENTRY (ctype, dtype ) \
109
109
test_mul_enumerate_b_types<ScalarType::dtype>();
110
110
111
- ET_FORALL_REAL_TYPES_AND (Half, ENUMERATE_TEST_ENTRY)
111
+ ET_FORALL_REALHBF16_TYPES ( ENUMERATE_TEST_ENTRY)
112
112
113
113
#undef ENUMERATE_TEST_ENTRY
114
114
}
115
+
116
+ template <ScalarType DTYPE>
117
+ void test_optimized_path_ignores_leading_1_dimensions () {
118
+ TensorFactory<DTYPE> tf;
119
+
120
+ const std::vector<int32_t > sizes1 = {1 , 1 , 2 , 2 };
121
+ const std::vector<int32_t > sizes2 = {1 , 2 , 2 };
122
+
123
+ // Destination for the mul.
124
+ Tensor out = tf.zeros (sizes1);
125
+
126
+ // Multiply two tensors
127
+ op_mul_out (
128
+ tf.make (sizes1, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }), tf.ones (sizes2), out);
129
+ EXPECT_TENSOR_CLOSE (out, tf.make (sizes1, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }));
130
+ }
131
+
132
+ template <ScalarType DTYPE>
133
+ void test_broadcast_a2b () {
134
+ TensorFactory<DTYPE> tf_a;
135
+
136
+ std::vector<std::vector<int32_t >> b_sizeses = {
137
+ {2 },
138
+ {1 , 2 },
139
+ };
140
+ for (const auto & b_sizes : b_sizeses) {
141
+ // a and b of different shapes
142
+ Tensor a = tf_a.make ({2 , 2 }, /* data=*/ {1 , 2 , 3 , 4 });
143
+ Tensor b = tf_a.make (b_sizes, /* data=*/ {2 , 2 });
144
+
145
+ // Destination for output of mul.
146
+ Tensor out = tf_a.zeros ({2 , 2 });
147
+
148
+ // Check that it matches the expected output.
149
+ EXPECT_TENSOR_CLOSE (
150
+ op_mul_out (a, b, out), tf_a.make ({2 , 2 }, /* data=*/ {2 , 4 , 6 , 8 }));
151
+ }
152
+ }
153
+
154
+ template <ScalarType DTYPE>
155
+ void test_broadcast_b2a () {
156
+ TensorFactory<DTYPE> tf_a;
157
+ // a and b of different shapes
158
+ Tensor a = tf_a.make ({2 }, /* data=*/ {2 , 2 });
159
+ Tensor b = tf_a.make ({2 , 2 }, /* data=*/ {1 , 2 , 3 , 4 });
160
+
161
+ // Destination for output of mul.
162
+ Tensor out = tf_a.zeros ({2 , 2 });
163
+
164
+ // Check that it matches the expected output.
165
+ EXPECT_TENSOR_CLOSE (
166
+ op_mul_out (a, b, out), tf_a.make ({2 , 2 }, /* data=*/ {2 , 4 , 6 , 8 }));
167
+ }
168
+
169
+ template <ScalarType DTYPE>
170
+ void test_scalar_input_broadcast () {
171
+ TensorFactory<DTYPE> tf_a;
172
+
173
+ // a is a 1d tensor and b is a scalar
174
+ Tensor a = tf_a.make ({2 }, /* data=*/ {2 , 2 });
175
+ Tensor b = tf_a.make ({}, /* data=*/ {2 });
176
+
177
+ // Destination for output of mul.
178
+ Tensor out = tf_a.make ({2 }, /* data=*/ {2 , 2 });
179
+ Tensor expected = tf_a.make ({2 }, /* data=*/ {4 , 4 });
180
+
181
+ // Check that it matches the expected output.
182
+ EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
183
+ EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
184
+ }
115
185
};
116
186
117
187
class OpMulScalarOutTest : public OperatorTest {
@@ -141,6 +211,14 @@ TEST_F(OpMulOutTest, DoubleTensors) {
141
211
test_floating_point_mul_out<ScalarType::Double>();
142
212
}
143
213
214
+ TEST_F (OpMulOutTest, HalfTensors) {
215
+ test_floating_point_mul_out<ScalarType::Half>();
216
+ }
217
+
218
+ TEST_F (OpMulOutTest, BFloat16Tensors) {
219
+ test_floating_point_mul_out<ScalarType::BFloat16>();
220
+ }
221
+
144
222
TEST_F (OpMulOutTest, BoolTensors) {
145
223
TensorFactory<ScalarType::Bool> tf;
146
224
@@ -166,18 +244,12 @@ TEST_F(OpMulOutTest, BoolTensors) {
166
244
}
167
245
168
246
TEST_F (OpMulOutTest, OptimizedPathIgnoresLeading1Dimensions) {
169
- TensorFactory<ScalarType::Float> tf;
247
+ #define ENUMERATE_TEST_ENTRY (ctype, dtype ) \
248
+ test_optimized_path_ignores_leading_1_dimensions<ScalarType::dtype>();
170
249
171
- const std::vector<int32_t > sizes1 = {1 , 1 , 2 , 2 };
172
- const std::vector<int32_t > sizes2 = {1 , 2 , 2 };
250
+ ET_FORALL_FLOATHBF16_TYPES (ENUMERATE_TEST_ENTRY);
173
251
174
- // Destination for the mul.
175
- Tensor out = tf.zeros (sizes1);
176
-
177
- // Multiply two tensors
178
- op_mul_out (
179
- tf.make (sizes1, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }), tf.ones (sizes2), out);
180
- EXPECT_TENSOR_CLOSE (out, tf.make (sizes1, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }));
252
+ #undef ENUMERATE_TEST_ENTRY
181
253
}
182
254
183
255
// Mismatched shape tests.
@@ -202,40 +274,16 @@ TEST_F(OpMulOutTest, MismatchedNonBroadcastableInputShapesDies) {
202
274
203
275
// Broadcast tensor b's size to tensor a's size
204
276
TEST_F (OpMulOutTest, BroadcastA2BTest) {
205
- TensorFactory<ScalarType::Int> tf_a;
206
-
207
- std::vector<std::vector<int32_t >> b_sizeses = {
208
- {2 },
209
- {1 , 2 },
210
- };
211
- for (const auto & b_sizes : b_sizeses) {
212
- // a and b of different shapes
213
- Tensor a = tf_a.make ({2 , 2 }, /* data=*/ {1 , 2 , 3 , 4 });
214
- Tensor b = tf_a.make (b_sizes, /* data=*/ {2 , 2 });
215
-
216
- // Destination for output of mul.
217
- Tensor out = tf_a.zeros ({2 , 2 });
218
-
219
- // Check that it matches the expected output.
220
- EXPECT_TENSOR_CLOSE (
221
- op_mul_out (a, b, out), tf_a.make ({2 , 2 }, /* data=*/ {2 , 4 , 6 , 8 }));
222
- }
277
+ test_broadcast_a2b<ScalarType::Int>();
278
+ test_broadcast_a2b<ScalarType::Half>();
279
+ test_broadcast_a2b<ScalarType::BFloat16>();
223
280
}
224
281
225
282
// Broadcast tensor a's size to tensor b's size
226
283
TEST_F (OpMulOutTest, BroadcastB2ATest) {
227
- TensorFactory<ScalarType::Int> tf_a;
228
-
229
- // a and b of different shapes
230
- Tensor a = tf_a.make ({2 }, /* data=*/ {2 , 2 });
231
- Tensor b = tf_a.make ({2 , 2 }, /* data=*/ {1 , 2 , 3 , 4 });
232
-
233
- // Destination for output of mul.
234
- Tensor out = tf_a.zeros ({2 , 2 });
235
-
236
- // Check that it matches the expected output.
237
- EXPECT_TENSOR_CLOSE (
238
- op_mul_out (a, b, out), tf_a.make ({2 , 2 }, /* data=*/ {2 , 4 , 6 , 8 }));
284
+ test_broadcast_b2a<ScalarType::Int>();
285
+ test_broadcast_b2a<ScalarType::Half>();
286
+ test_broadcast_b2a<ScalarType::BFloat16>();
239
287
}
240
288
241
289
// Broadcast tensor a and b's size to a new size c.
@@ -256,19 +304,9 @@ TEST_F(OpMulOutTest, BroadcastAB2CTest) {
256
304
}
257
305
258
306
TEST_F (OpMulOutTest, ScalarInputBroadcastTest) {
259
- TensorFactory<ScalarType::Int> tf_a;
260
-
261
- // a is a 1d tensor and b is a scalar
262
- Tensor a = tf_a.make ({2 }, /* data=*/ {2 , 2 });
263
- Tensor b = tf_a.make ({}, /* data=*/ {2 });
264
-
265
- // Destination for output of mul.
266
- Tensor out = tf_a.make ({2 }, /* data=*/ {2 , 2 });
267
- Tensor expected = tf_a.make ({2 }, /* data=*/ {4 , 4 });
268
-
269
- // Check that it matches the expected output.
270
- EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
271
- EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
307
+ test_scalar_input_broadcast<ScalarType::Int>();
308
+ test_scalar_input_broadcast<ScalarType::Half>();
309
+ test_scalar_input_broadcast<ScalarType::BFloat16>();
272
310
}
273
311
274
312
TEST_F (OpMulOutTest, MismatchedOutputShapesDies) {
0 commit comments