Skip to content

Commit be6d364

Browse files
committed
fixup! [mlir][vector] Update v.contract -> v.outerproduct tests
Remove duplicate test
1 parent 834fa55 commit be6d364

File tree

1 file changed

+15
-57
lines changed

1 file changed

+15
-57
lines changed

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir

Lines changed: 15 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -406,17 +406,17 @@ func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %
406406
iterator_types = ["parallel", "parallel", "reduction"]
407407
}
408408

409-
#matvec_accesses_0 = [
409+
#matvec_accesses_1 = [
410410
affine_map<(m, k) -> (m, k)>,
411411
affine_map<(m, k) -> (k)>,
412412
affine_map<(m, k) -> (m)>
413413
]
414-
#matvec_trait_0 = {
415-
indexing_maps = #matvec_accesses_0,
414+
#matvec_trait_1 = {
415+
indexing_maps = #matvec_accesses_1,
416416
iterator_types = ["parallel", "reduction"]
417417
}
418418

419-
// CHECK-LABEL: func.func @masked_extract_contract2(
419+
// CHECK-LABEL: func.func @masked_matvec_mk_k_m(
420420
// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
421421
// CHECK-SAME: %{{.*}}: vector<3xf32>,
422422
// CHECK-SAME: %{{.*}}: vector<2xf32>,
@@ -431,17 +431,17 @@ func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %
431431
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
432432
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
433433

434-
func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
435-
%arg1: vector<3xf32>,
436-
%arg2: vector<2xf32>,
437-
%m: vector<2x3xi1>) -> vector<2xf32> {
438-
%0 = vector.mask %m { vector.contract #matvec_trait_0 %arg0, %arg1, %arg2
434+
func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
435+
%arg1: vector<3xf32>,
436+
%arg2: vector<2xf32>,
437+
%m: vector<2x3xi1>) -> vector<2xf32> {
438+
%0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
439439
: vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
440440
return %0 : vector<2xf32>
441441
}
442442

443443

444-
// CHECK-LABEL: func.func @masked_extract_contract2_scalable_parallel_dim(
444+
// CHECK-LABEL: func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
445445
// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
446446
// CHECK-SAME: %{{.*}}: vector<3xf32>,
447447
// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
@@ -455,57 +455,15 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
455455

456456
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
457457
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
458-
func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
459-
%arg1: vector<3xf32>,
460-
%arg2: vector<[2]xf32>,
461-
%m: vector<[2]x3xi1>) -> vector<[2]xf32> {
462-
%0 = vector.mask %m { vector.contract #matvec_trait_0 %arg0, %arg1, %arg2
458+
func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
459+
%arg1: vector<3xf32>,
460+
%arg2: vector<[2]xf32>,
461+
%m: vector<[2]x3xi1>) -> vector<[2]xf32> {
462+
%0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
463463
: vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
464464
return %0 : vector<[2]xf32>
465465
}
466466

467-
#matvec_accesses_1 = [
468-
affine_map<(m, k) -> (m, k)>,
469-
affine_map<(m, k) -> (k)>,
470-
affine_map<(m, k) -> (m)>
471-
]
472-
#matvec_trait_1 = {
473-
indexing_maps = #matvec_accesses_1,
474-
iterator_types = ["parallel", "reduction"]
475-
}
476-
477-
// CHECK-LABEL: @masked_matvec_mk_k_m
478-
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
479-
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
480-
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
481-
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
482-
func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
483-
// CHECK: vector.transpose %[[MASK]]
484-
// CHECK: vector.transpose %[[MAT]]
485-
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
486-
%res = vector.mask %mask {
487-
vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
488-
: vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
489-
} : vector<4x2xi1> -> vector<4xf32>
490-
return %res : vector<4xf32>
491-
}
492-
493-
// CHECK-LABEL: @masked_matvec_mk_k_m_scalable_parallel_dim
494-
// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
495-
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
496-
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
497-
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
498-
func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
499-
// CHECK: vector.transpose %[[MASK]]
500-
// CHECK: vector.transpose %[[MAT]]
501-
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
502-
%res = vector.mask %mask {
503-
vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
504-
: vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
505-
} : vector<[4]x2xi1> -> vector<[4]xf32>
506-
return %res : vector<[4]xf32>
507-
}
508-
509467
#matvec_accesses_2 = [
510468
affine_map<(m, k) -> (k, m)>,
511469
affine_map<(m, k) -> (k)>,

0 commit comments

Comments
 (0)