// RUN: mlir-opt --transform-interpreter --canonicalize --split-input-file %s | FileCheck %s

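// Continuous tiling of a statically shaped matmul along dimension 0 with
// target_size = 9: the tile and chunk sizes are produced as !transform.any_op
// handles, the matmul is multiway-split into chunks of 18, 4, 2 and 1 rows,
// and each chunk is tiled with the matching tile size (9, 4, 2 and 1).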
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
    %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
    transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.any_op {
    ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
      %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
    transform.yield
  }
}

func.func @continuous_tile_linalg_matmul(
    %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
    -> tensor<25x25xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
                     outs(%arg2: tensor<25x25xf32>)
    -> tensor<25x25xf32>

  return %0 : tensor<25x25xf32>
}

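// The 18-row chunk (two tiles of 9) is tiled with an scf.for loop of step 9;
// the remaining 4-, 2- and 1-row chunks each appear as a single matmul on
// extracted slices.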
// CHECK-LABEL: @continuous_tile_linalg_matmul
// CHECK-SAME: (%[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>) -> tensor<25x25xf32> {
// CHECK: %[[C18:.+]] = arith.constant 18 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C9:.+]] = arith.constant 9 : index
// CHECK: %[[XSIN18:.+]] = tensor.extract_slice %[[IN1]][0, 0] [18, 34] [1, 1] : tensor<25x34xf32> to tensor<18x34xf32>
// CHECK: %[[XSOUT18:.+]] = tensor.extract_slice %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<25x25xf32> to tensor<18x25xf32>
// CHECK: %[[R0:.+]] = scf.for %[[IDX:.+]] = %[[C0]] to %[[C18]] step %[[C9]] iter_args(%[[XSOUT18ARG:.+]] = %[[XSOUT18]]) -> (tensor<18x25xf32>) {
// CHECK: %[[XSIN19:.+]] = tensor.extract_slice %[[XSIN18]][%[[IDX]], 0] [9, 34] [1, 1] : tensor<18x34xf32> to tensor<9x34xf32>
// CHECK: %[[XSOUT9:.+]] = tensor.extract_slice %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<18x25xf32> to tensor<9x25xf32>
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[XSIN19]], %[[IN2]] : tensor<9x34xf32>, tensor<34x25xf32>) outs(%[[XSOUT9]] : tensor<9x25xf32>) -> tensor<9x25xf32>
// CHECK: %[[INS9:.+]] = tensor.insert_slice %[[MATMUL]] into %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<9x25xf32> into tensor<18x25xf32>
// CHECK: scf.yield %[[INS9]] : tensor<18x25xf32>
// CHECK: }
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[R0]] into %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<18x25xf32> into tensor<25x25xf32>
// CHECK: %[[XS1:.+]] = tensor.extract_slice %[[IN1]][18, 0] [7, 34] [1, 1] : tensor<25x34xf32> to tensor<7x34xf32>
// CHECK: %[[XS2:.+]] = tensor.extract_slice %[[INS]][18, 0] [7, 25] [1, 1] : tensor<25x25xf32> to tensor<7x25xf32>
// CHECK: %[[XS3:.+]] = tensor.extract_slice %[[XS1]][0, 0] [4, 34] [1, 1] : tensor<7x34xf32> to tensor<4x34xf32>
// CHECK: %[[XS4:.+]] = tensor.extract_slice %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<7x25xf32> to tensor<4x25xf32>
// CHECK: %[[R1:.+]] = linalg.matmul ins(%[[XS3]], %[[IN2]] : tensor<4x34xf32>, tensor<34x25xf32>) outs(%[[XS4]] : tensor<4x25xf32>) -> tensor<4x25xf32>
// CHECK: %[[INS5:.+]] = tensor.insert_slice %[[R1]] into %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<4x25xf32> into tensor<7x25xf32>
// CHECK: %[[XS6:.+]] = tensor.extract_slice %[[XS1]][4, 0] [3, 34] [1, 1] : tensor<7x34xf32> to tensor<3x34xf32>
// CHECK: %[[XS7:.+]] = tensor.extract_slice %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<7x25xf32> to tensor<3x25xf32>
// CHECK: %[[XS8:.+]] = tensor.extract_slice %[[XS6]][0, 0] [2, 34] [1, 1] : tensor<3x34xf32> to tensor<2x34xf32>
// CHECK: %[[XS9:.+]] = tensor.extract_slice %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<3x25xf32> to tensor<2x25xf32>
// CHECK: %[[R2:.+]] = linalg.matmul ins(%[[XS8]], %[[IN2]] : tensor<2x34xf32>, tensor<34x25xf32>) outs(%[[XS9]] : tensor<2x25xf32>) -> tensor<2x25xf32>
// CHECK: %[[INS10:.+]] = tensor.insert_slice %[[R2]] into %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<2x25xf32> into tensor<3x25xf32>
// CHECK: %[[XS11:.+]] = tensor.extract_slice %[[XS6]][2, 0] [1, 34] [1, 1] : tensor<3x34xf32> to tensor<1x34xf32>
// CHECK: %[[XS12:.+]] = tensor.extract_slice %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<3x25xf32> to tensor<1x25xf32>
// CHECK: %[[R3:.+]] = linalg.matmul ins(%[[XS11]], %[[IN2]] : tensor<1x34xf32>, tensor<34x25xf32>) outs(%[[XS12]] : tensor<1x25xf32>) -> tensor<1x25xf32>
// CHECK: %[[INS13:.+]] = tensor.insert_slice %[[R3]] into %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<1x25xf32> into tensor<3x25xf32>
// CHECK: %[[INS14:.+]] = tensor.insert_slice %[[INS13]] into %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<3x25xf32> into tensor<7x25xf32>
// CHECK: %[[INS15:.+]] = tensor.insert_slice %[[INS14]] into %[[INS]][18, 0] [7, 25] [1, 1] : tensor<7x25xf32> into tensor<25x25xf32>
// CHECK: return %[[INS15]] : tensor<25x25xf32>

// -----

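// Same scenario as above, but continuous_tile_sizes returns the tile and chunk
// sizes as static !transform.param<i64> values; the same chunk structure
// (18, 4, 2 and 1 rows) is expected.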
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.param<i64>
    %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
    transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.param<i64> {
    ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.param<i64>):
      %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
    transform.yield
  }
}

func.func @continuous_tile_static_linalg_matmul(
    %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
    -> tensor<25x25xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
                     outs(%arg2: tensor<25x25xf32>)
    -> tensor<25x25xf32>

  return %0 : tensor<25x25xf32>
}

// CHECK-LABEL: @continuous_tile_static_linalg_matmul
// CHECK-SAME: (%[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>) -> tensor<25x25xf32> {
// CHECK: %[[C9:.+]] = arith.constant 9 : index
// CHECK: %[[C18:.+]] = arith.constant 18 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[XSIN18:.+]] = tensor.extract_slice %[[IN1]][0, 0] [18, 34] [1, 1] : tensor<25x34xf32> to tensor<18x34xf32>
// CHECK: %[[XSOUT18:.+]] = tensor.extract_slice %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<25x25xf32> to tensor<18x25xf32>
// CHECK: %[[R0:.+]] = scf.for %[[IDX:.+]] = %[[C0]] to %[[C18]] step %[[C9]] iter_args(%[[XSOUT18ARG:.+]] = %[[XSOUT18]]) -> (tensor<18x25xf32>) {
// CHECK: %[[XSIN19:.+]] = tensor.extract_slice %[[XSIN18]][%[[IDX]], 0] [9, 34] [1, 1] : tensor<18x34xf32> to tensor<9x34xf32>
// CHECK: %[[XSOUT9:.+]] = tensor.extract_slice %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<18x25xf32> to tensor<9x25xf32>
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[XSIN19]], %[[IN2]] : tensor<9x34xf32>, tensor<34x25xf32>) outs(%[[XSOUT9]] : tensor<9x25xf32>) -> tensor<9x25xf32>
// CHECK: %[[INS9:.+]] = tensor.insert_slice %[[MATMUL]] into %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<9x25xf32> into tensor<18x25xf32>
// CHECK: scf.yield %[[INS9]] : tensor<18x25xf32>
// CHECK: }
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[R0]] into %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<18x25xf32> into tensor<25x25xf32>
// CHECK: %[[XS1:.+]] = tensor.extract_slice %[[IN1]][18, 0] [7, 34] [1, 1] : tensor<25x34xf32> to tensor<7x34xf32>
// CHECK: %[[XS2:.+]] = tensor.extract_slice %[[INS]][18, 0] [7, 25] [1, 1] : tensor<25x25xf32> to tensor<7x25xf32>
// CHECK: %[[XS3:.+]] = tensor.extract_slice %[[XS1]][0, 0] [4, 34] [1, 1] : tensor<7x34xf32> to tensor<4x34xf32>
// CHECK: %[[XS4:.+]] = tensor.extract_slice %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<7x25xf32> to tensor<4x25xf32>
// CHECK: %[[R1:.+]] = linalg.matmul ins(%[[XS3]], %[[IN2]] : tensor<4x34xf32>, tensor<34x25xf32>) outs(%[[XS4]] : tensor<4x25xf32>) -> tensor<4x25xf32>
// CHECK: %[[INS5:.+]] = tensor.insert_slice %[[R1]] into %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<4x25xf32> into tensor<7x25xf32>
// CHECK: %[[XS6:.+]] = tensor.extract_slice %[[XS1]][4, 0] [3, 34] [1, 1] : tensor<7x34xf32> to tensor<3x34xf32>
// CHECK: %[[XS7:.+]] = tensor.extract_slice %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<7x25xf32> to tensor<3x25xf32>
// CHECK: %[[XS8:.+]] = tensor.extract_slice %[[XS6]][0, 0] [2, 34] [1, 1] : tensor<3x34xf32> to tensor<2x34xf32>
// CHECK: %[[XS9:.+]] = tensor.extract_slice %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<3x25xf32> to tensor<2x25xf32>
// CHECK: %[[R2:.+]] = linalg.matmul ins(%[[XS8]], %[[IN2]] : tensor<2x34xf32>, tensor<34x25xf32>) outs(%[[XS9]] : tensor<2x25xf32>) -> tensor<2x25xf32>
// CHECK: %[[INS10:.+]] = tensor.insert_slice %[[R2]] into %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<2x25xf32> into tensor<3x25xf32>
// CHECK: %[[XS11:.+]] = tensor.extract_slice %[[XS6]][2, 0] [1, 34] [1, 1] : tensor<3x34xf32> to tensor<1x34xf32>
// CHECK: %[[XS12:.+]] = tensor.extract_slice %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<3x25xf32> to tensor<1x25xf32>
// CHECK: %[[R3:.+]] = linalg.matmul ins(%[[XS11]], %[[IN2]] : tensor<1x34xf32>, tensor<34x25xf32>) outs(%[[XS12]] : tensor<1x25xf32>) -> tensor<1x25xf32>
// CHECK: %[[INS13:.+]] = tensor.insert_slice %[[R3]] into %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<1x25xf32> into tensor<3x25xf32>
// CHECK: %[[INS14:.+]] = tensor.insert_slice %[[INS13]] into %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<3x25xf32> into tensor<7x25xf32>
// CHECK: %[[INS15:.+]] = tensor.insert_slice %[[INS14]] into %[[INS]][18, 0] [7, 25] [1, 1] : tensor<7x25xf32> into tensor<25x25xf32>
// CHECK: return %[[INS15]] : tensor<25x25xf32>

// -----

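// Continuous tiling of a matmul with fully dynamic shapes. transform.foreach
// iterates with the zip_shortest attribute, the chunk extents are computed at
// runtime with affine.min, and each chunk is tiled by an scf.for loop with
// step 9, 8, 4, 2 and 1, respectively.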
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
    %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
    transform.foreach %linalg_splits, %tile_sizes {zip_shortest} : !transform.any_op, !transform.any_op {
    ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
      %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
    transform.yield
  }
}

func.func @continuous_tile_dynamic_linalg_matmul(
    %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
    -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%arg2: tensor<?x?xf32>)
    -> tensor<?x?xf32>

  return %0 : tensor<?x?xf32>
}

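// Each affine map computes a chunk's extent by rounding the rows still to be
// processed down to a multiple of the corresponding tile size (9, 8, 4, 2, 1).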
// CHECK: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 9) * 9, s1)>
// CHECK: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2] -> (((s0 mod 9) floordiv 8) * 8, s1 - s2)>
// CHECK: #[[$MAP6:.*]] = affine_map<()[s0, s1, s2, s3] -> ((((s0 mod 9) mod 8) floordiv 4) * 4, s1 - s2 - s3)>
// CHECK: #[[$MAP9:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 mod 9) mod 4) floordiv 2) * 2, s1 - s2 - s3 - s4)>
// CHECK: #[[$MAP12:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> ((s0 mod 9) mod 2, s1 - s2 - s3 - s4 - s5)>
// CHECK-LABEL: @continuous_tile_dynamic_linalg_matmul
// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[AM0:.*]] = affine.min #[[$MAP0]]()[%{{.*}}, %{{.*}}]
// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM0]] step %[[C9]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: %[[AM4:.*]] = affine.min #[[$MAP3]]()[%{{.*}}, %{{.*}}, %[[AM0]]]
// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM4]] step %[[C8]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: %[[AM8:.*]] = affine.min #[[$MAP6]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]]]
// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM8]] step %[[C4]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: %[[AM12:.*]] = affine.min #[[$MAP9]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]]]
// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM12]] step %[[C2]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: %[[AM16:.*]] = affine.min #[[$MAP12]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]], %[[AM12]]]
// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM16]] step %[[C1]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<1x?xf32>) -> tensor<1x?xf32>
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>