Skip to content

Commit eb1d506

Browse files
authored
[MLIR][Vector] Refactor tests for contract -> OP transforms (2/N) (#73445)
This is a direct follow-up of #73348. The matvec trait that's used for `@matvec_m_mk_k` was incorrectly updated from: ``` #redpar_vecmattrans_accesses = [ affine_map<(m, k) -> (m)>, affine_map<(m, k) -> (m, k)>, affine_map<(m, k) -> (k)> ] indexing_maps = #redpar_vecmattrans_accesses, iterator_types = ["reduction", "parallel"] } ``` to: ``` #matvec_accesses_4 = [ affine_map<(m, k) -> (k)>, affine_map<(m, k) -> (k, m)>, affine_map<(m, k) -> (m)> ] indexing_maps = #matvec_accesses_4, iterator_types = ["parallel", "reduction"] } ``` Note that these traits describe identical matvec operation, hence the `CHECK` lines are identical for both. Also, `#redpar_vecmattrans_trait` is identical to `#matvec_trait_8` that's already present in: * "vector-contract-to-outerproduct-matvec-transforms.mlir" For this reason: * `@matvec_m_mk_k` is moved near other tests that already use `#matvec_trait_8`, * `#redpar_vecmattrans_trait` is replaced `#matvec_trait_8`. This is a part of a larger effort to add cases with scalable vectors to tests for the Vector dialect. I am refactoring these tests so that it's easier to identify what cases are tested and where to add tests for scalable vectors. Implements #72834.
1 parent 104b7c6 commit eb1d506

File tree

1 file changed

+22
-32
lines changed

1 file changed

+22
-32
lines changed

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@
4646
iterator_types = ["parallel", "reduction"]
4747
}
4848

49-
#redpar_vecmattrans_accesses = [
50-
affine_map<(m, k) -> (m)>,
51-
affine_map<(m, k) -> (m, k)>,
52-
affine_map<(m, k) -> (k)>
49+
#matvec_accesses_8 = [
50+
affine_map<(k, m) -> (k)>,
51+
affine_map<(k, m) -> (k, m)>,
52+
affine_map<(k, m) -> (m)>
5353
]
54-
#redpar_vecmattrans_trait = {
55-
indexing_maps = #redpar_vecmattrans_accesses,
54+
#matvec_trait_8 = {
55+
indexing_maps = #matvec_accesses_8,
5656
iterator_types = ["reduction", "parallel"]
5757
}
5858

@@ -321,23 +321,6 @@ func.func @matvec_k_km_m(%A: vector<2x2xf32>,
321321
return %0 : vector<2xf32>
322322
}
323323

324-
// CHECK-LABEL: func @matvec_m_mk_k
325-
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
326-
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
327-
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
328-
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
329-
// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
330-
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
331-
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
332-
// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
333-
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
334-
func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
335-
%x: vector<2xf32>,
336-
%b: vector<2xf32>) -> vector<2xf32> {
337-
%0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
338-
return %0 : vector<2xf32>
339-
}
340-
341324
// ============================================================================
342325
// Matvec 5 (masked + scalable)
343326
// ============================================================================
@@ -474,16 +457,23 @@ func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
474457
}
475458

476459
// ============================================================================
477-
// Matvec 8 (masked + scalable)
460+
// Matvec 8 (plain + masked + scalable)
478461
// ============================================================================
479-
#matvec_accesses_8 = [
480-
affine_map<(k, m) -> (k)>,
481-
affine_map<(k, m) -> (k, m)>,
482-
affine_map<(k, m) -> (m)>
483-
]
484-
#matvec_trait_8 = {
485-
indexing_maps = #matvec_accesses_8,
486-
iterator_types = ["reduction", "parallel"]
462+
// CHECK-LABEL: func @matvec_m_mk_k
463+
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
464+
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
465+
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
466+
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
467+
// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
468+
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
469+
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
470+
// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
471+
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
472+
func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
473+
%x: vector<2xf32>,
474+
%b: vector<2xf32>) -> vector<2xf32> {
475+
%0 = vector.contract #matvec_trait_8 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
476+
return %0 : vector<2xf32>
487477
}
488478

489479
// CHECK-LABEL: @masked_tmatvec_k_km_m

0 commit comments

Comments
 (0)