-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector][nfc] Refactor vector.contract matvec tests #72832
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector][nfc] Refactor vector.contract matvec tests #72832
Conversation
Update tests in "vector-contract-matvec-transforms.mlir" so that they are consistent with similar tests in: * "vector-contract-to-outerproduct-transforms.mlir". This is to enable further refactoring in a follow-up patch, namely to: * remove duplication (this will be much easier once consistent naming is used), * extend tests in "vector-contract-matvec-transforms.mlir" with cases for scalable vectors, * merge "vector-contract-matvec-transforms.mlir" and "vector-contract-to-outerproduct-transforms.mlir" (there's no need for 2 different files testing identical transformations). Overview of changes in this patch: 1. Simplify the test by removing MemRef wrappers - this test verifies Vector -> Vector transformations and MemRefs are not needed. 2. Use (m, k) indices instead of (i, j). 3. Rename function names. This is part of a larger effort to improve test coverage for scalable vectors in the Vector dialect.
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesUpdate tests in "vector-contract-matvec-transforms.mlir" so that they
This is to enable further refactoring in a follow-up patch, namely to:
Overview of changes in this patch:
This is part of a larger effort to improve test coverage for scalable Full diff: https://github.com/llvm/llvm-project/pull/72832.diff 1 Files Affected:
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index cfcb14a477b6b71..811fb589792b1a8 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
#matvec_accesses = [
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
]
#matvec_trait = {
indexing_maps = #matvec_accesses,
@@ -16,9 +16,9 @@
}
#mattransvec_accesses = [
- affine_map<(i, j) -> (j, i)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
]
#mattransvec_trait = {
indexing_maps = #mattransvec_accesses,
@@ -26,9 +26,9 @@
}
#vecmat_accesses = [
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (m)>
]
#vecmat_trait = {
indexing_maps = #vecmat_accesses,
@@ -36,9 +36,9 @@
}
#vecmattrans_accesses = [
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (j, i)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (m)>
]
#vecmattrans_trait = {
indexing_maps = #vecmattrans_accesses,
@@ -46,166 +46,118 @@
}
#redpar_vecmattrans_accesses = [
- affine_map<(i, j) -> (i)>,
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>
+ affine_map<(m, k) -> (m)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>
]
#redpar_vecmattrans_trait = {
indexing_maps = #redpar_vecmattrans_accesses,
iterator_types = ["reduction", "parallel"]
}
-// CHECK-LABEL: func @matvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @matvecmax2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m_max
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @mattransvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @vecmat2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_k_mk_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @redpar_vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_m_mk_k
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
module attributes {transform.with_named_sequence} {
|
@llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesUpdate tests in "vector-contract-matvec-transforms.mlir" so that they
This is to enable further refactoring in a follow-up patch, namely to:
Overview of changes in this patch:
This is part of a larger effort to improve test coverage for scalable Full diff: https://github.com/llvm/llvm-project/pull/72832.diff 1 Files Affected:
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index cfcb14a477b6b71..811fb589792b1a8 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
#matvec_accesses = [
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
]
#matvec_trait = {
indexing_maps = #matvec_accesses,
@@ -16,9 +16,9 @@
}
#mattransvec_accesses = [
- affine_map<(i, j) -> (j, i)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
]
#mattransvec_trait = {
indexing_maps = #mattransvec_accesses,
@@ -26,9 +26,9 @@
}
#vecmat_accesses = [
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (m)>
]
#vecmat_trait = {
indexing_maps = #vecmat_accesses,
@@ -36,9 +36,9 @@
}
#vecmattrans_accesses = [
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (j, i)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (m)>
]
#vecmattrans_trait = {
indexing_maps = #vecmattrans_accesses,
@@ -46,166 +46,118 @@
}
#redpar_vecmattrans_accesses = [
- affine_map<(i, j) -> (i)>,
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>
+ affine_map<(m, k) -> (m)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>
]
#redpar_vecmattrans_trait = {
indexing_maps = #redpar_vecmattrans_accesses,
iterator_types = ["reduction", "parallel"]
}
-// CHECK-LABEL: func @matvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @matvecmax2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m_max
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @mattransvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @vecmat2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_k_mk_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @redpar_vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_m_mk_k
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
module attributes {transform.with_named_sequence} {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, I don't spot anything off :)
Update tests in "vector-contract-matvec-transforms.mlir" so that they
are consistent with similar tests in:
This is to enable further refactoring in a follow-up patch, namely to:
is used),
for scalable vectors,
"vector-contract-to-outerproduct-transforms.mlir" (there's no need
for 2 different files testing identical transformations).
Overview of changes in this patch:
Vector -> Vector transformations and MemRefs are not needed.
This is part of a larger effort to improve test coverage for scalable
vectors in the Vector dialect. Implements #72834.