Skip to content

Commit 07de2f1

Browse files
committed
fixup! [mlir][vector] Extend CreateMaskFolder (llvm#75842)
Re-order matvec tests so that the one without masking is always first
1 parent c9e73e5 commit 07de2f1

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,23 @@ func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
235235
// ============================================================================
236236
// Matvec 2 (plain + masked + scalable)
237237
// ============================================================================
238+
// CHECK-LABEL: func @matvec_km_k_m
239+
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
240+
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
241+
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
242+
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
243+
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
244+
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
245+
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
246+
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
247+
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
248+
func.func @matvec_km_k_m(%A: vector<2x2xf32>,
249+
%x: vector<2xf32>,
250+
%b: vector<2xf32>) -> vector<2xf32> {
251+
%0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
252+
return %0 : vector<2xf32>
253+
}
254+
238255
// CHECK-LABEL: @masked_matvec_km_k_m
239256
// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
240257
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
@@ -273,26 +290,27 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
273290
return %res : vector<[4]xf32>
274291
}
275292

276-
// CHECK-LABEL: func @matvec_km_k_m
293+
// ============================================================================
294+
// Matvec 3 (plain + masked + scalable)
295+
// ============================================================================
296+
// CHECK-LABEL: func @matvec_k_mk_m
277297
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
278298
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
279299
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
280-
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
281-
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
282-
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
283-
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
284-
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
285-
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
286-
func.func @matvec_km_k_m(%A: vector<2x2xf32>,
300+
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
301+
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
302+
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
303+
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
304+
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
305+
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
306+
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
307+
func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
287308
%x: vector<2xf32>,
288309
%b: vector<2xf32>) -> vector<2xf32> {
289-
%0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
310+
%0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
290311
return %0 : vector<2xf32>
291312
}
292313

293-
// ============================================================================
294-
// Matvec 3 (plain + masked + scalable)
295-
// ============================================================================
296314
// CHECK-LABEL: @masked_matvec_k_mk_m
297315
// CHECK-SAME: %[[A:.+]]: vector<4x2xf32>
298316
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
@@ -331,24 +349,6 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
331349
return %res : vector<[4]xf32>
332350
}
333351

334-
// CHECK-LABEL: func @matvec_k_mk_m
335-
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
336-
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
337-
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
338-
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
339-
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
340-
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
341-
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
342-
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
343-
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
344-
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
345-
func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
346-
%x: vector<2xf32>,
347-
%b: vector<2xf32>) -> vector<2xf32> {
348-
%0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
349-
return %0 : vector<2xf32>
350-
}
351-
352352
// ============================================================================
353353
// Matvec 4 (plain + masked + scalable)
354354
// ============================================================================

0 commit comments

Comments
 (0)