@@ -25,6 +25,51 @@ transform.sequence failures(propagate) {
25
25
26
26
// -----
27
27
28
+ // CHECK-LABEL: @vectorize_matmul_memref
29
+ // CHECK-SAME: %[[A:.*]]: memref<24x12xf32>
30
+ // CHECK-SAME: %[[B:.*]]: memref<12x25xf32>
31
+ // CHECK-SAME: %[[C:.*]]: memref<24x25xf32>
32
+ func.func @vectorize_matmul_memref (%arg0: memref <24 x12 xf32 >,
33
+ %arg1: memref <12 x25 xf32 >,
34
+ %arg2: memref <24 x25 xf32 >) {
35
+ // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
36
+ // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
37
+ // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
38
+ // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
39
+ // CHECK: vector.transfer_write %[[vR]], %[[C]]
40
+ linalg.matmul ins (%arg0 , %arg1 : memref <24 x12 xf32 >, memref <12 x25 xf32 >) outs (%arg2 : memref <24 x25 xf32 >)
41
+ return
42
+ }
43
+
44
+ transform.sequence failures (propagate ) {
45
+ ^bb1 (%arg1: !transform.any_op ):
46
+ %0 = transform.structured.match ops {[" linalg.matmul" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
47
+ %1 = get_closest_isolated_parent %0 : (!transform.any_op ) -> !transform.any_op
48
+ %2 = transform.structured.vectorize %1 : (!transform.any_op ) -> !transform.any_op
49
+ }
50
+
51
+ // -----
52
+
53
+ // CHECK-LABEL: @vectorize_copy_memref
54
+ // CHECK-SAME: %[[A:.*]]: memref<100x100xf32>,
55
+ // CHECK-SAME: %[[B:.*]]: memref<100x100xf32>
56
+ func.func @vectorize_copy_memref (%arg0: memref <100 x100 xf32 >,
57
+ %arg1: memref <100 x100 xf32 >) {
58
+ // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
59
+ // CHECK: vector.transfer_write %[[vA]], %[[B]]
60
+ linalg.copy ins (%arg0 : memref <100 x100 xf32 >) outs (%arg1 : memref <100 x100 xf32 >)
61
+ return
62
+ }
63
+
64
+ transform.sequence failures (propagate ) {
65
+ ^bb1 (%arg1: !transform.any_op ):
66
+ %0 = transform.structured.match ops {[" linalg.copy" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
67
+ %1 = get_closest_isolated_parent %0 : (!transform.any_op ) -> !transform.any_op
68
+ %2 = transform.structured.vectorize %1 : (!transform.any_op ) -> !transform.any_op
69
+ }
70
+
71
+ // -----
72
+
28
73
#map0 = affine_map <()[s0 ] -> (-s0 + 12 , 7 )>
29
74
#map1 = affine_map <()[s0 ] -> (-s0 + 7 )>
30
75
0 commit comments