-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Vector] Refactor tests for contract -> OP transforms #73217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch refactors tests for : `vector.contract` -> `vector.outerproduct` for matvec operations (b += Ax). Relevant tests from the following 2 files: * vector-contract-matvec-transforms.mlir * vector-contract-to-outerproduct-transforms.mlir are combined into one: * vector-contract-to-outerproduct-matvec-transforms.mlir Summary of changes (on top of moving things between files): * duplicate tests are removed, * missing cases for scalable vectors are added, * names of LIT variables and operation traits are unified. Implements llvm#72834.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesThis patch refactors tests for :
for matvec operations (b += Ax). Relevant tests from the following 2
are combined into one:
Summary of changes (on top of moving things between files):
Implements #72834. Patch is 60.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73217.diff 3 Files Affected:
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
deleted file mode 100644
index 811fb589792b1a8..000000000000000
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ /dev/null
@@ -1,170 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
-
-#matvec_accesses = [
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (m)>
-]
-#matvec_trait = {
- indexing_maps = #matvec_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-#matvecmax_trait = {
- indexing_maps = #matvec_accesses,
- iterator_types = ["parallel", "reduction"],
- kind = #vector.kind<maxf>
-}
-
-#mattransvec_accesses = [
- affine_map<(m, k) -> (k, m)>,
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (m)>
-]
-#mattransvec_trait = {
- indexing_maps = #mattransvec_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-#vecmat_accesses = [
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (m)>
-]
-#vecmat_trait = {
- indexing_maps = #vecmat_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-#vecmattrans_accesses = [
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (k, m)>,
- affine_map<(m, k) -> (m)>
-]
-#vecmattrans_trait = {
- indexing_maps = #vecmattrans_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-#redpar_vecmattrans_accesses = [
- affine_map<(m, k) -> (m)>,
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (k)>
-]
-#redpar_vecmattrans_trait = {
- indexing_maps = #redpar_vecmattrans_accesses,
- iterator_types = ["reduction", "parallel"]
-}
-
-// CHECK-LABEL: func @matvec_mk_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_mk_k_m_max
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_km_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_km_k_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_k_mk_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_k_km_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_k_km_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_m_mk_k
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- } : !transform.op<"func.func">
- transform.yield
- }
-}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
new file mode 100644
index 000000000000000..9030cc6b1ace4e0
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -0,0 +1,610 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// Matvec operations:
+/// b += A * x.
+/// (b and x are 1-d vectors, A is a 2-d matrix). ATM three different variants
+/// are tested:
+/// * plain (no mask, fixed-wdith vectors),
+/// * masked (fixed-width vectors,
+/// * scalable (mask + scalable vectors).
+///
+/// TODO: These tests were extracted from 2 different files. If you find the
+/// formatting inconsistent, please update accordingly.
+
+#matvec_accesses_1 = [
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
+]
+#matvec_trait_1 = {
+ indexing_maps = #matvec_accesses_1,
+ iterator_types = ["parallel", "reduction"]
+}
+
+#matvecmax_trait = {
+ indexing_maps = #matvec_accesses_1,
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<maxf>
+}
+
+#matvec_accesses_2 = [
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
+]
+#matvec_trait_2 = {
+ indexing_maps = #matvec_accesses_2,
+ iterator_types = ["parallel", "reduction"]
+}
+
+#matvec_accesses_3 = [
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (m)>
+]
+#matvec_trait_3 = {
+ indexing_maps = #matvec_accesses_3,
+ iterator_types = ["parallel", "reduction"]
+}
+
+#matvec_accesses_4 = [
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (m)>
+]
+#matvec_trait_4 = {
+ indexing_maps = #matvec_accesses_4,
+ iterator_types = ["parallel", "reduction"]
+}
+
+#matvec_accesses_5 = [
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+ indexing_maps = #matvec_accesses_5,
+ iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_6 = [
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+ indexing_maps = #matvec_accesses_6,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// TODO: No mask
+#matvec_accesses_7 = [
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+ indexing_maps = #matvec_accesses_7,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// TODO: No mask
+#matvec_accesses_8 = [
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_8 = {
+ indexing_maps = #matvec_accesses_8,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// ============================================================================
+// Matvec 1 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_matvec_mk_k_m(
+// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<2xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>,
+ %x: vector<3xf32>,
+ %b: vector<2xf32>,
+ %m: vector<2x3xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
+ : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
+// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%A: vector<[2]x3xf32>,
+ %x: vector<3xf32>,
+ %b: vector<[2]xf32>,
+ %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
+ : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
+// ============================================================================
+// Matvec 1 - max (plain)
+// ============================================================================
+// CHECK-LABEL: func @matvec_mk_k_m_max
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ============================================================================
+// Matvec 2 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>,
+ %x: vector<2xf32>,
+ %b: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[A]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_2 %A, %x, %b
+ : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+ %x: vector<2xf32>,
+ %b: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[A]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_2 %A, %x, %b
+ : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<[4]x2xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> i...
[truncated]
|
To make review easier in future think it'd be nice if each of:
Were separate commits in the PR, even though the PR will be squashed into a single commit :) |
Haha, that was half-expected :) Let me do it as separate PRs - I am yet to get used to having separate commits as meaningful chunks for work (sorry, old habits). Abandoning in favour of #73348 and follow-ups therein. |
This patch refactors tests for :
vector.contract
->vector.outerproduct
for matvec operations (b += Ax). Relevant tests from the following 2
files:
are combined into one:
Summary of changes (on top of moving things between files):
Implements #72834.