13
13
14
14
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map
15
15
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
16
- // CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) {
16
+ // CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>
17
17
// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
18
18
// CHECK: vector.transfer_write
19
19
// CHECK-NOT: permutation_map
20
20
// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
21
21
func.func @xfer_write_transposing_permutation_map (
22
22
%vec: vector <4 x8 xi16 >,
23
- %mem: memref <2 x2 x8 x4 xi16 >) {
23
+ %mem: memref <2 x2 x8 x4 xi16 >,
24
+ %idx: index ) {
24
25
25
- %c0 = arith.constant 0 : index
26
- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] {
26
+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ] {
27
27
in_bounds = [true , true ],
28
28
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d3 , d2 )>
29
29
} : vector <4 x8 xi16 >, memref <2 x2 x8 x4 xi16 >
30
30
31
31
return
32
32
}
33
33
34
- // Even with out-of-bounds, it is safe to apply this pattern
34
+ // Even with out-of-bounds accesses, it is safe to apply this pattern
35
+
35
36
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds
36
37
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
37
- // CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>) {
38
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
38
+ // CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>,
39
+ // CHECK-SAME: %[[IDX:.*]]: index) {
39
40
// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
40
41
// Expect the in_bounds attribute to be preserved. Since we don't print it when
41
42
// all flags are "false", it should not appear in the output.
42
43
// CHECK-NOT: in_bounds
43
44
// CHECK: vector.transfer_write
44
45
// CHECK-NOT: permutation_map
45
- // CHECK-SAME: %[[TR]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
46
+ // CHECK-SAME: %[[TR]], %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
46
47
func.func @xfer_write_transposing_permutation_map_out_of_bounds (
47
48
%vec: vector <4 x8 xi16 >,
48
- %mem: memref <2 x2 x?x?xi16 >) {
49
+ %mem: memref <2 x2 x?x?xi16 >,
50
+ %idx: index ) {
49
51
50
- %c0 = arith.constant 0 : index
51
- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] {
52
+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ] {
52
53
in_bounds = [false , false ],
53
54
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d3 , d2 )>
54
55
} : vector <4 x8 xi16 >, memref <2 x2 x?x?xi16 >
@@ -59,18 +60,19 @@ func.func @xfer_write_transposing_permutation_map_out_of_bounds(
59
60
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
60
61
// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
61
62
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
62
- // CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>) {
63
+ // CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>
63
64
// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
64
65
// CHECK: vector.transfer_write
65
66
// CHECK-NOT: permutation_map
66
67
// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
67
68
func.func @xfer_write_transposing_permutation_map_with_mask_scalable (
68
69
%vec: vector <4 x[8 ]xi16 >,
69
70
%mem: memref <2 x2 x?x4 xi16 >,
70
- %mask: vector <[8 ]x4 xi1 >) {
71
+ %mask: vector <[8 ]x4 xi1 >,
72
+ %idx: index ) {
71
73
72
74
%c0 = arith.constant 0 : index
73
- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ], %mask {
75
+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ], %mask {
74
76
in_bounds = [true , true ],
75
77
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d3 , d2 )>
76
78
} : vector <4 x[8 ]xi16 >, memref <2 x2 x?x4 xi16 >
@@ -79,16 +81,18 @@ func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
79
81
}
80
82
81
83
// Masked version is not supported
84
+
82
85
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_masked
83
86
// CHECK-NOT: vector.transpose
84
87
func.func @xfer_write_transposing_permutation_map_masked (
85
88
%vec: vector <4 x8 xi16 >,
86
89
%mem: memref <2 x2 x8 x4 xi16 >,
87
- %mask: vector <8 x4 xi1 >) {
90
+ %mask: vector <8 x4 xi1 >,
91
+ %idx: index ) {
88
92
89
93
%c0 = arith.constant 0 : index
90
94
vector.mask %mask {
91
- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] {
95
+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ] {
92
96
in_bounds = [true , true ],
93
97
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d3 , d2 )>
94
98
} : vector <4 x8 xi16 >, memref <2 x2 x8 x4 xi16 >
@@ -128,7 +132,8 @@ func.func @xfer_write_non_transposing_permutation_map(
128
132
return
129
133
}
130
134
131
- // Even with out-of-bounds, it is safe to apply this pattern
135
+ // Even with out-of-bounds accesses, it is safe to apply this pattern
136
+
132
137
// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
133
138
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
134
139
// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
@@ -157,8 +162,7 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
157
162
// CHECK: func.func @permutation_with_mask_xfer_write_scalable(
158
163
// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
159
164
// CHECK-SAME: %[[MEM:.*]]: memref<1x4x?x1xi16>,
160
- // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
161
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
165
+ // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
162
166
// CHECK: %[[BC_1:.*]] = vector.broadcast %[[VEC]] : vector<4x[8]xi16> to vector<1x4x[8]xi16>
163
167
// CHECK: %[[BC_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x4x[8]xi1>
164
168
// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BC_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
@@ -167,18 +171,19 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
167
171
func.func @permutation_with_mask_xfer_write_scalable (
168
172
%vec: vector <4 x[8 ]xi16 >,
169
173
%mem: memref <1 x4 x?x1 xi16 >,
170
- %mask: vector <4 x[8 ]xi1 >){
174
+ %mask: vector <4 x[8 ]xi1 >,
175
+ %idx: index ){
171
176
172
- %c0 = arith.constant 0 : index
173
- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ], %mask {
177
+ vector.transfer_write %vec , %mem [%idx , %idx , %idx , %idx ], %mask {
174
178
in_bounds = [true , true ],
175
179
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
176
180
} : vector <4 x[8 ]xi16 >, memref <1 x4 x?x1 xi16 >
177
181
178
182
return
179
183
}
180
184
181
- // transfer_write in MaskOp case not supported.
185
+ // Masked version is not supported
186
+
182
187
// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
183
188
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
184
189
// CHECK-SAME: %[[VEC:.*]]: vector<16xf32>,
@@ -204,18 +209,19 @@ func.func @masked_permutation_xfer_write_fixed_width(
204
209
// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
205
210
// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
206
211
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>,
207
- // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
212
+ // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
208
213
// CHECK-SAME: -> tensor<?x?x?x?xf32> {
209
214
// CHECK-NOT: vector.transpose
210
215
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
211
216
func.func @masked_permutation_xfer_write_scalable (
212
217
%vec: vector <4 x[8 ]xi16 >,
213
218
%dest: tensor <?x?x?x?xf32 >,
214
- %mask: vector <4 x[8 ]xi1 >) -> tensor <?x?x?x?xf32 > {
219
+ %mask: vector <4 x[8 ]xi1 >,
220
+ %idx: index ) -> tensor <?x?x?x?xf32 > {
215
221
216
222
%c0 = arith.constant 0 : index
217
223
%res = vector.mask %mask {
218
- vector.transfer_write %vec , %dest [%c0 , %c0 , %c0 , %c0 ] {
224
+ vector.transfer_write %vec , %dest [%idx , %idx , %idx , %idx ] {
219
225
in_bounds = [true , true ],
220
226
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
221
227
} : vector <4 x[8 ]xi16 >, tensor <?x?x?x?xf32 >
@@ -224,22 +230,23 @@ func.func @masked_permutation_xfer_write_scalable(
224
230
return %res : tensor <?x?x?x?xf32 >
225
231
}
226
232
227
- // transfer_write in MaskOp case not supported.
233
+ // Masked version is not supported
234
+
228
235
// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
229
236
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>
230
237
// CHECK-SAME: %[[VEC:.*]]: vector<14x8x16xf32>
231
- // CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
238
+ // CHECK-SAME: %[[DIM:.*]]: index, %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
232
239
// CHECK-NOT: vector.broadcast
233
240
// CHECK: vector.mask %0 { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
234
241
func.func @masked_non_permutation_xfer_write_fixed_width (
235
242
%dest : tensor <?x?x?x?xf32 >,
236
243
%vec : vector <14 x8 x16 xf32 >,
237
- %dim : index ) -> tensor <?x?x?x?xf32 > {
244
+ %dim : index ,
245
+ %idx: index ) -> tensor <?x?x?x?xf32 > {
238
246
239
- %c0 = arith.constant 0 : index
240
247
%mask = vector.create_mask %dim , %dim , %dim : vector <14 x8 x16 xi1 >
241
248
%res = vector.mask %mask {
242
- vector.transfer_write %vec , %dest [%c0 , %c0 , %c0 , %c0 ] {
249
+ vector.transfer_write %vec , %dest [%idx , %idx , %idx , %idx ] {
243
250
in_bounds = [false , false , true ],
244
251
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>
245
252
} : vector <14 x8 x16 xf32 >, tensor <?x?x?x?xf32 >
@@ -259,25 +266,23 @@ func.func @masked_non_permutation_xfer_write_fixed_width(
259
266
260
267
// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_fixed_width(
261
268
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
262
- // CHECK-SAME: %[[IDX_1:.*]]: index,
263
- // CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x4x2xf32> {
264
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
269
+ // CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x4x2xf32> {
265
270
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
266
- // CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x4xi1>
267
- // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
271
+ // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x4xi1>
272
+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
268
273
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
269
274
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
270
275
// CHECK: return %[[TRANSPOSE]] : vector<8x4x2xf32>
271
276
func.func @permutation_with_mask_xfer_read_fixed_width (
272
277
%mem: memref <?x?xf32 >,
273
278
%dim_1: index ,
274
- %dim_2: index ) -> (vector <8 x4 x2 xf32 >) {
279
+ %dim_2: index ,
280
+ %idx: index ) -> (vector <8 x4 x2 xf32 >) {
275
281
276
- %c0 = arith.constant 0 : index
277
- %cst_0 = arith.constant 0.000000e+00 : f32
282
+ %pad = arith.constant 0.000000e+00 : f32
278
283
279
284
%mask = vector.create_mask %dim_2 , %dim_1 : vector <2 x4 xi1 >
280
- %res = vector.transfer_read %mem [%c0 , %c0 ], %cst_0 , %mask {
285
+ %res = vector.transfer_read %mem [%idx , %idx ], %pad , %mask {
281
286
in_bounds = [true , true , true ],
282
287
permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>
283
288
} : memref <?x?xf32 >, vector <8 x4 x2 xf32 >
@@ -287,46 +292,45 @@ func.func @permutation_with_mask_xfer_read_fixed_width(
287
292
288
293
// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_scalable(
289
294
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
290
- // CHECK-SAME: %[[IDX_1:.*]]: index,
291
- // CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> {
292
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
293
- // CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
294
- // CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x[4]xi1>
295
- // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
295
+ // CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x[4]x2xf32> {
296
+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
297
+ // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x[4]xi1>
298
+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
296
299
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
297
300
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
298
301
// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
299
302
func.func @permutation_with_mask_xfer_read_scalable (
300
303
%mem: memref <?x?xf32 >,
301
304
%dim_1: index ,
302
- %dim_2: index ) -> (vector <8 x[4 ]x2 xf32 >) {
305
+ %dim_2: index ,
306
+ %idx: index ) -> (vector <8 x[4 ]x2 xf32 >) {
303
307
304
- %c0 = arith.constant 0 : index
305
- %cst_0 = arith.constant 0.000000e+00 : f32
308
+ %pad = arith.constant 0.000000e+00 : f32
306
309
307
310
%mask = vector.create_mask %dim_2 , %dim_1 : vector <2 x[4 ]xi1 >
308
- %res = vector.transfer_read %mem [%c0 , %c0 ], %cst_0 , %mask {
311
+ %res = vector.transfer_read %mem [%idx , %idx ], %pad , %mask {
309
312
in_bounds = [true , true , true ],
310
313
permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>
311
314
} : memref <?x?xf32 >, vector <8 x[4 ]x2 xf32 >
312
315
313
316
return %res : vector <8 x[4 ]x2 xf32 >
314
317
}
315
318
316
- // transfer_read in MaskOp case not supported.
319
+ // Masked version is not supported
320
+
317
321
// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
318
322
// CHECK-SAME: %[[DEST:.*]]: tensor<?x1xf32>,
319
323
// CHECK-SAME: %[[MASK:.*]]: vector<4x1xi1>
320
324
// CHECK-NOT: vector.transpose
321
325
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
322
326
func.func @masked_permutation_xfer_read_fixed_width (
323
327
%dest: tensor <?x1 xf32 >,
324
- %mask : vector <4 x1 xi1 >) {
328
+ %mask : vector <4 x1 xi1 >,
329
+ %idx: index ) {
325
330
326
- %cst = arith.constant 0.000000e+00 : f32
327
- %c0 = arith.constant 0 : index
331
+ %pad = arith.constant 0.000000e+00 : f32
328
332
%3 = vector.mask %mask {
329
- vector.transfer_read %dest [%c0 , %c0 ], %cst {
333
+ vector.transfer_read %dest [%idx , %idx ], %pad {
330
334
permutation_map = affine_map <(d0 , d1 ) -> (d1 , 0 , d0 )>
331
335
} : tensor <?x1 xf32 >, vector <1 x4 x4 xf32 >
332
336
} : vector <4 x1 xi1 > -> vector <1 x4 x4 xf32 >
@@ -337,18 +341,18 @@ func.func @masked_permutation_xfer_read_fixed_width(
337
341
338
342
// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
339
343
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
340
- // CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
344
+ // CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>
341
345
// CHECK-NOT: vector.transpose
342
346
// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
343
347
func.func @masked_permutation_xfer_read_scalable (
344
348
%dest: tensor <?x?xf32 >,
345
- %mask : vector <2 x[4 ]xi1 >) -> vector <8 x[4 ]x2 xf32 > {
349
+ %mask : vector <2 x[4 ]xi1 >,
350
+ %idx: index ) -> vector <8 x[4 ]x2 xf32 > {
346
351
347
- %c0 = arith.constant 0 : index
348
- %cst_0 = arith.constant 0.000000e+00 : f32
352
+ %pad = arith.constant 0.000000e+00 : f32
349
353
350
354
%res = vector.mask %mask {
351
- vector.transfer_read %dest [%c0 , %c0 ], %cst_0 {
355
+ vector.transfer_read %dest [%idx , %idx ], %pad {
352
356
in_bounds = [true , true , true ],
353
357
permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>
354
358
} : tensor <?x?xf32 >, vector <8 x[4 ]x2 xf32 >
@@ -377,41 +381,41 @@ module attributes {transform.with_named_sequence} {
377
381
378
382
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
379
383
// CHECK: func.func @transfer_read_reduce_rank_scalable(
380
- // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
381
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
382
- // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
384
+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
385
+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
383
386
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
384
387
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
385
388
func.func @transfer_read_reduce_rank_scalable (
386
- %mem: memref <?x?x?x?xf32 >) -> vector <8 x[4 ]x2 x3 xf32 > {
389
+ %mem: memref <?x?x?x?xf32 >, %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
387
390
388
- %c0 = arith.constant 0 : index
389
- %cst_0 = arith.constant 0.000000e+00 : f32
391
+ %pad = arith.constant 0.000000e+00 : f32
390
392
391
- %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst_0 {
393
+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
392
394
in_bounds = [true , true , true , true ],
393
395
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
394
396
} : memref <?x?x?x?xf32 >, vector <8 x[4 ]x2 x3 xf32 >
395
397
396
398
return %res : vector <8 x[4 ]x2 x3 xf32 >
397
399
}
398
400
399
- // Masked case not supported.
401
+ // Masked version is not supported
402
+
400
403
// CHECK-LABEL: func.func @masked_transfer_read_reduce_rank(
401
404
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
402
- // CHECK-SAME: %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
405
+ // CHECK-SAME: %[[DIM:.*]]: index,
406
+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
403
407
// CHECK-NOT: vector.broadcast
404
408
// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
405
409
func.func @masked_transfer_read_reduce_rank (
406
410
%mem: memref <?x?x?x?xf32 >,
407
- %dim: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
411
+ %dim: index ,
412
+ %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
408
413
409
- %c0 = arith.constant 0 : index
410
- %cst_0 = arith.constant 0.000000e+00 : f32
414
+ %pad = arith.constant 0.000000e+00 : f32
411
415
%mask = vector.create_mask %dim , %dim: vector <[4 ]x3 xi1 >
412
416
413
417
%res = vector.mask %mask {
414
- vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst_0 {
418
+ vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
415
419
in_bounds = [true , true , true , true ],
416
420
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
417
421
} : memref <?x?x?x?xf32 >, vector <8 x[4 ]x2 x3 xf32 >
0 commit comments