-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Vector] Refactor tests for contract -> OP transforms (3/N) #73447
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Vector] Refactor tests for contract -> OP transforms (3/N) #73447
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesThis is a direct follow-up of #73348. The matvec trait that's used for
to:
Note that these traits describe identical matvec operation, hence the Also,
For this reason:
This is a part of a larger effort to add cases with scalable vectors to Implements #72834. Patch is 39.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73447.diff 1 Files Affected:
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index 3ca3d344c1abe04..e84a43feaff39dc 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -1,5 +1,17 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// Matvec operations:
+/// b += A * x.
+/// (b and x are 1-d vectors, A is a 2-d matrix). ATM three different variants
+/// are tested:
+/// * plain (no mask, fixed-wdith vectors),
+/// * masked (fixed-width vectors,
+/// * scalable (mask + scalable vectors).
+///
+/// TODO: These tests were extracted from 2 different files. If you find the
+/// formatting inconsistent, please update accordingly.
+
#matvec_accesses_1 = [
affine_map<(m, k) -> (m, k)>,
affine_map<(m, k) -> (k)>,
@@ -46,19 +58,67 @@
iterator_types = ["parallel", "reduction"]
}
-#redpar_vecmattrans_accesses = [
- affine_map<(m, k) -> (m)>,
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (k)>
+#matvec_accesses_5 = [
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+ indexing_maps = #matvec_accesses_5,
+ iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_6 = [
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+ indexing_maps = #matvec_accesses_6,
+ iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_7 = [
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+ indexing_maps = #matvec_accesses_7,
+ iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_8 = [
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (m)>
]
-#redpar_vecmattrans_trait = {
- indexing_maps = #redpar_vecmattrans_accesses,
+#matvec_trait_8 = {
+ indexing_maps = #matvec_accesses_8,
iterator_types = ["reduction", "parallel"]
}
// ============================================================================
// Matvec 1 (plain + masked + scalable)
// ============================================================================
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: func.func @masked_matvec_mk_k_m(
// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
// CHECK-SAME: %{{.*}}: vector<3xf32>,
@@ -73,12 +133,11 @@
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
-
-func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
- %arg1: vector<3xf32>,
- %arg2: vector<2xf32>,
+func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>,
+ %x: vector<3xf32>,
+ %b: vector<2xf32>,
%m: vector<2x3xi1>) -> vector<2xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
: vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
return %0 : vector<2xf32>
}
@@ -97,46 +156,28 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
-func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
- %arg1: vector<3xf32>,
- %arg2: vector<[2]xf32>,
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%A: vector<[2]x3xf32>,
+ %x: vector<3xf32>,
+ %b: vector<[2]xf32>,
%m: vector<[2]x3xi1>) -> vector<[2]xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
: vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
return %0 : vector<[2]xf32>
}
-// CHECK-LABEL: func @matvec_mk_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
// ============================================================================
// Matvec 1 - max (plain)
// ============================================================================
// CHECK-LABEL: func @matvec_mk_k_m_max
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
%x: vector<2xf32>,
@@ -149,38 +190,38 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
// Matvec 2 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: @masked_matvec_km_k_m
-// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<4xf32>,
+func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>,
+ %x: vector<2xf32>,
+ %b: vector<4xf32>,
%mask: vector<4x2xi1>) -> vector<4xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+ vector.contract #matvec_trait_2 %A, %x, %b
: vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
} : vector<4x2xi1> -> vector<4xf32>
return %res : vector<4xf32>
}
// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<[4]xf32>,
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+ %x: vector<2xf32>,
+ %b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+ vector.contract #matvec_trait_2 %A, %x, %b
: vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
} : vector<[4]x2xi1> -> vector<[4]xf32>
return %res : vector<[4]xf32>
@@ -188,13 +229,13 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
// CHECK-LABEL: func @matvec_km_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_km_k_m(%A: vector<2x2xf32>,
%x: vector<2xf32>,
@@ -207,54 +248,53 @@ func.func @matvec_km_k_m(%A: vector<2x2xf32>,
// Matvec 3 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: @masked_matvec_k_mk_m
-// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[A:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<4xf32>,
+func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<4xf32>,
%mask: vector<4x2xi1>) -> vector<4xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+ vector.contract #matvec_trait_3 %x, %A, %b
: vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
} : vector<4x2xi1> -> vector<4xf32>
return %res : vector<4xf32>
}
// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A:.+]]: vector<[4]x2xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<[4]xf32>,
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+ vector.contract #matvec_trait_3 %x, %A, %b
: vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
} : vector<[4]x2xi1> -> vector<[4]xf32>
return %res : vector<[4]xf32>
}
-
// CHECK-LABEL: func @matvec_k_mk_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
%x: vector<2xf32>,
@@ -266,253 +306,290 @@ func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
// ============================================================================
// Matvec 4 (plain + masked + scalable)
// ============================================================================
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<[4]xf32>,
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+ %x: vector<2xf32>,
+ %b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+ vector.contract #matvec_trait_4 %x, %A, %b
: vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
} : vector<[4]x2xi1> -> vector<[4]xf32>
return %res : vector<[4]xf32>
}
// CHECK-LABEL: @masked_matvec_k_km_m
-// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<4xf32>,
+func.func @masked_matvec_k_km_m(%A: vector<2x4xf32>,
+ %x: vector<2xf32>,
+ ...
[truncated]
|
Depends on #73445 - only review the most recent commit. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've had a look through and it all looks good to me (though it's a little tricky to diff these changes :)).
This patch refactors tests for: vector.contract -> vector.outerproduct for matvec operations (b += Ax). Summary of changes: * names of LIT variables are unified, * "plain" tests (i.e. without masking and with fixed-width vectors) are moved to the top of their respective sections, * missing "plain" cases are added. This is a part of a larger effort to add cases with scalable vectors to tests for the Vector dialect. I am refactoring these tests so that it's easier to identify what cases are tested and where to add tests for scalable vectors. Implements llvm#72834.
2854e83
to
ddd89c3
Compare
This patch refactors tests for:
for matvec operations (b += Ax). Summary of changes:
are moved to the top of their respective sections,
This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect. I am refactoring these tests so that it's
easier to identify what cases are tested and where to add tests for
scalable vectors.
Implements #72834.