-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Vector] Refactor tests for contract -> OP transforms #73348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Vector] Refactor tests for contract -> OP transforms #73348
Conversation
This patch refactors tests for: vector.contract -> vector.outerproduct transformations for matvec operations (b += Ax). Specifically, relevant tests from the following 2 files: vector-contract-matvec-transforms.mlir vector-contract-to-outerproduct-transforms.mlir are combined into one: vector-contract-to-outerproduct-matvec-transforms.mlir All original tests are preserved and no new tests are added. Implements llvm#72834.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesThis patch refactors tests for: vector.contract -> vector.outerproduct transformations for matvec operations (b += Ax). Specifically, relevant vector-contract-matvec-transforms.mlir are combined into one: vector-contract-to-outerproduct-matvec-transforms.mlir All original tests are preserved and no new tests are added. Implements #72834. Patch is 43.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73348.diff 2 Files Affected:
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index 811fb589792b1a8..3ca3d344c1abe04 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,47 +1,48 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
-#matvec_accesses = [
+#matvec_accesses_1 = [
affine_map<(m, k) -> (m, k)>,
affine_map<(m, k) -> (k)>,
affine_map<(m, k) -> (m)>
]
-#matvec_trait = {
- indexing_maps = #matvec_accesses,
+#matvec_trait_1 = {
+ indexing_maps = #matvec_accesses_1,
iterator_types = ["parallel", "reduction"]
}
+
#matvecmax_trait = {
- indexing_maps = #matvec_accesses,
+ indexing_maps = #matvec_accesses_1,
iterator_types = ["parallel", "reduction"],
kind = #vector.kind<maxf>
}
-#mattransvec_accesses = [
+#matvec_accesses_2 = [
affine_map<(m, k) -> (k, m)>,
affine_map<(m, k) -> (k)>,
affine_map<(m, k) -> (m)>
]
-#mattransvec_trait = {
- indexing_maps = #mattransvec_accesses,
+#matvec_trait_2 = {
+ indexing_maps = #matvec_accesses_2,
iterator_types = ["parallel", "reduction"]
}
-#vecmat_accesses = [
+#matvec_accesses_3 = [
affine_map<(m, k) -> (k)>,
affine_map<(m, k) -> (m, k)>,
affine_map<(m, k) -> (m)>
]
-#vecmat_trait = {
- indexing_maps = #vecmat_accesses,
+#matvec_trait_3 = {
+ indexing_maps = #matvec_accesses_3,
iterator_types = ["parallel", "reduction"]
}
-#vecmattrans_accesses = [
+#matvec_accesses_4 = [
affine_map<(m, k) -> (k)>,
affine_map<(m, k) -> (k, m)>,
affine_map<(m, k) -> (m)>
]
-#vecmattrans_trait = {
- indexing_maps = #vecmattrans_accesses,
+#matvec_trait_4 = {
+ indexing_maps = #matvec_accesses_4,
iterator_types = ["parallel", "reduction"]
}
@@ -55,6 +56,56 @@
iterator_types = ["reduction", "parallel"]
}
+// ============================================================================
+// Matvec 1 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: func.func @masked_matvec_mk_k_m(
+// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<2xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<2xf32>,
+ %m: vector<2x3xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
+// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
// CHECK-LABEL: func @matvec_mk_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
@@ -69,10 +120,13 @@
func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
%x: vector<2xf32>,
%b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
+// ============================================================================
+// Matvec 1 - max (plain)
+// ============================================================================
// CHECK-LABEL: func @matvec_mk_k_m_max
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
@@ -91,6 +145,47 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
return %0 : vector<2xf32>
}
+// ============================================================================
+// Matvec 2 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+ : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+ : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<[4]x2xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
// CHECK-LABEL: func @matvec_km_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
@@ -104,10 +199,52 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
func.func @matvec_km_k_m(%A: vector<2x2xf32>,
%x: vector<2xf32>,
%b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
+// ============================================================================
+// Matvec 3 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: @masked_matvec_k_mk_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<[4]x2xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+
// CHECK-LABEL: func @matvec_k_mk_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
@@ -122,10 +259,51 @@ func.func @matvec_km_k_m(%A: vector<2x2xf32>,
func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
%x: vector<2xf32>,
%b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+ %0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
+// ============================================================================
+// Matvec 4 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<[4]x2xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
// CHECK-LABEL: func @matvec_k_km_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
@@ -139,7 +317,7 @@ func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
func.func @matvec_k_km_m(%A: vector<2x2xf32>,
%x: vector<2xf32>,
%b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+ %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
@@ -156,10 +334,193 @@ func.func @matvec_k_km_m(%A: vector<2x2xf32>,
func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
%x: vector<2xf32>,
%b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+ %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
+// ============================================================================
+// Matvec 5 (masked + scalable)
+// ============================================================================
+#matvec_accesses_5 = [
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+ indexing_maps = #matvec_accesses_5,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+ : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+ : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<2x[4]xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+// Matvec 6 (masked + scalable)
+// ============================================================================
+#matvec_accesses_6 = [
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+ indexing_maps = #matvec_accesses_6,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+ : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+ : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<2x[4]xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+// Matvec 7 (masked + scalable)
+// ============================================================================
+#matvec_accesses_7 = [
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+ indexing_maps = #matvec_accesses_7,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK: ...
[truncated]
|
Rename the test file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked these changes with https://meldmerge.org/, and didn't see anything off, so this LGTM 👍
This is a direct follow-up of llvm#73348. The matvec trait that's used for `@matvec_m_mk_k` was incorrectly updated from: ``` affine_map<(m, k) -> (m)>, affine_map<(m, k) -> (m, k)>, affine_map<(m, k) -> (k)> ] indexing_maps = #redpar_vecmattrans_accesses, iterator_types = ["reduction", "parallel"] } ``` to: ``` affine_map<(m, k) -> (k)>, affine_map<(m, k) -> (k, m)>, affine_map<(m, k) -> (m)> ] indexing_maps = #matvec_accesses_4, iterator_types = ["parallel", "reduction"] } ``` Note that these traits describe identical matvec operation, hence the `CHECK` lines are identical for both. Also, `#redpar_vecmattrans_trait` is identical to `#matvec_trait_8` that's already present in: * "vector-contract-to-outerproduct-matvec-transforms.mlir" For this reason: * `@matvec_m_mk_k` is moved near other tests that already use `#matvec_trait_8`, * `#redpar_vecmattrans_trait` is replaced `#matvec_trait_8`. This is a part of a larger effort to add cases with scalable vectors to tests for the Vector dialect. I am refactoring these tests so that it's easier to identify what cases are tested and where to add tests for scalable vectors. Implements llvm#72834.
…3445) This is a direct follow-up of #73348. The matvec trait that's used for `@matvec_m_mk_k` was incorrectly updated from: ``` #redpar_vecmattrans_accesses = [ affine_map<(m, k) -> (m)>, affine_map<(m, k) -> (m, k)>, affine_map<(m, k) -> (k)> ] indexing_maps = #redpar_vecmattrans_accesses, iterator_types = ["reduction", "parallel"] } ``` to: ``` #matvec_accesses_4 = [ affine_map<(m, k) -> (k)>, affine_map<(m, k) -> (k, m)>, affine_map<(m, k) -> (m)> ] indexing_maps = #matvec_accesses_4, iterator_types = ["parallel", "reduction"] } ``` Note that these traits describe identical matvec operation, hence the `CHECK` lines are identical for both. Also, `#redpar_vecmattrans_trait` is identical to `#matvec_trait_8` that's already present in: * "vector-contract-to-outerproduct-matvec-transforms.mlir" For this reason: * `@matvec_m_mk_k` is moved near other tests that already use `#matvec_trait_8`, * `#redpar_vecmattrans_trait` is replaced `#matvec_trait_8`. This is a part of a larger effort to add cases with scalable vectors to tests for the Vector dialect. I am refactoring these tests so that it's easier to identify what cases are tested and where to add tests for scalable vectors. Implements #72834.
This patch refactors tests for:
transformations for matvec operations (b += Ax). Specifically, relevant
tests from the following 2 files:
are combined into one:
All original tests are preserved and no new tests are added.
This is a part of a larger effort to add cases with scalable vectors
to tests for the Vector dialect. I am refactoring these test as a
preparation for follow-up patches.
Implements #72834.