Skip to content

Commit 8c329f4

Browse files
committed
Added tests to check that the DropUnitDim transform is not applied to a contraction op that has user-defined indexing_maps.
1 parent f696116 commit 8c329f4

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,23 @@ func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memr
 
 // -----
 
+func.func @negative_singleton_batch_matmul_to_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @negative_singleton_batch_matmul_to_matmul_memref
+  // CHECK-NOT: collapse_shape
+  // CHECK-NOT: linalg.matmul
+  // CHECK-NOT: expand_shape
+  linalg.batch_matmul indexing_maps = [
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+        affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+      ]
+      ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+      outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
 func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> {
 // CHECK-LABEL: @singleton_batch_matvec
 // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
@@ -135,6 +152,20 @@ func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg
 
 // -----
 
+func.func @negative_matmul_to_matvec(%arg0: memref<?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @negative_matmul_to_matvec
+  // CHECK-NOT: linalg.matvec
+  linalg.matmul indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d2)>,
+        affine_map<(d0, d1, d2) -> (d2, d1)>,
+        affine_map<(d0, d1, d2) -> (d0, d1)>
+      ]
+      ins(%arg0, %arg1: memref<?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
+  return
+}
+
+// -----
+
 func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
 // CHECK-LABEL: @matmul_to_vecmat
 // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>

0 commit comments

Comments
 (0)