@@ -164,6 +164,37 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
164
164
return %0: vector <1 x1 x2 x16 xf32 >
165
165
}
166
166
167
+ // -----
168
+
169
+ // Masked vector.contract whose operands and result carry a leading unit
+ // dimension. The expected output keeps the contraction on its original
+ // vector<1x16x8xf32> / vector<1x8x16xf32> / vector<1x16x16xf32> types inside
+ // vector.mask, with the mask operand routed through a vector.broadcast.
+ // NOTE(review): presumably this guards the leading-unit-dim cast-away
+ // patterns exercised by the neighbouring cast_away_* tests; confirm against
+ // the RUN line of this file.
+ // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
170
+ // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
171
+ // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
172
+
173
+ // CHECK-LABEL: func @not_insert_cast_for_contraction_under_mask
174
+ // CHECK: %[[MASK:.+]] = vector.constant_mask
175
+ // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
176
+ // CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
177
+ // CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
178
+ // CHECK: return %[[RET]] : vector<1x16x16xf32>
179
+
180
+ #contraction_accesses0 = [
181
+ affine_map <(l , i , j , k ) -> (l , i , k )>,
182
+ affine_map <(l , i , j , k ) -> (l , k , j )>,
183
+ affine_map <(l , i , j , k ) -> (l , i , j )>
184
+ ]
185
+ #contraction_trait0 = {
186
+ indexing_maps = #contraction_accesses0 ,
187
+ iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ]
188
+ }
189
+
190
+ func.func @not_insert_cast_for_contraction_under_mask (%arg0: vector <1 x16 x8 xf32 >, %arg1: vector <1 x8 x16 xf32 >, %arg2: vector <1 x16 x16 xf32 >) -> vector <1 x16 x16 xf32 > {
191
+ %mask = vector.constant_mask [1 , 15 , 15 , 8 ] : vector <1 x16 x16 x8 xi1 >
192
+ %0 = vector.mask %mask {
193
+ vector.contract #contraction_trait0 %arg0 , %arg1 , %arg2 : vector <1 x16 x8 xf32 >, vector <1 x8 x16 xf32 > into vector <1 x16 x16 xf32 >
194
+ } : vector <1 x16 x16 x8 xi1 > -> vector <1 x16 x16 xf32 >
195
+ return %0 : vector <1 x16 x16 xf32 >
196
+ }
197
+
167
198
// -----
168
199
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
169
200
func.func @cast_away_extract_strided_slice_leading_one_dims (%arg0: vector <1 x8 x8 xf16 >) -> vector <1 x1 x8 xf16 > {
@@ -253,6 +284,24 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
253
284
254
285
// -----
255
286
287
+ // Masked vector.transfer_read from memref<1x1x4xf16> into vector<1x4xf16>.
+ // The expected output keeps the read on its original memref/vector types
+ // inside vector.mask, with the mask operand routed through a
+ // vector.broadcast, and returns the masked result unchanged in type.
+ // NOTE(review): "fo4" in the test name looks like a typo for "for" (compare
+ // the sibling *_for_* tests); if renamed, the label directive and the
+ // func.func symbol below must be changed together.
+ // CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask
288
+ // CHECK: %[[MASK:.+]] = vector.constant_mask
289
+ // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
290
+ // CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
291
+ // CHECK-SAME: vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
292
+ // CHECK: return %[[RET]] : vector<1x4xf16>
293
+ func.func @not_insert_cast_fo4_transfer_read_under_mask (%arg0: memref <1 x1 x4 xf16 >) -> vector <1 x4 xf16 > {
294
+ %c0 = arith.constant 0 : index
295
+ %f0 = arith.constant 0. : f16
296
+ %mask = vector.constant_mask [1 , 3 ] : vector <1 x4 xi1 >
297
+ %ret = vector.mask %mask {
298
+ vector.transfer_read %arg0 [%c0 , %c0 , %c0 ], %f0 {in_bounds = [true , true ]} : memref <1 x1 x4 xf16 >, vector <1 x4 xf16 >
299
+ } : vector <1 x4 xi1 > -> vector <1 x4 xf16 >
300
+ return %ret: vector <1 x4 xf16 >
301
+ }
302
+
303
+ // -----
304
+
256
305
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
257
306
func.func @cast_away_transfer_write_leading_one_dims (%arg0: memref <1 x4 x8 x16 xf16 >, %arg1: vector <1 x4 xf16 >) {
258
307
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -286,6 +335,23 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
286
335
287
336
// -----
288
337
338
+ // Masked vector.transfer_write of vector<1x4xf16> into memref<1x1x4xf16>.
+ // The expected output keeps the write on its original vector/memref types
+ // inside vector.mask, with the mask operand routed through a
+ // vector.broadcast; the masked op produces no result, so only a bare
+ // return is expected afterwards.
+ // CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
339
+ // CHECK: %[[MASK:.+]] = vector.constant_mask
340
+ // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
341
+ // CHECK: vector.mask %[[CASTED_MASK]] {
342
+ // CHECK-SAME: vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
343
+ // CHECK: return
344
+ func.func @not_insert_cast_for_transfer_write_under_mask (%arg0: memref <1 x1 x4 xf16 >, %arg1: vector <1 x4 xf16 >) {
345
+ %c0 = arith.constant 0 : index
346
+ %mask = vector.constant_mask [1 , 3 ] : vector <1 x4 xi1 >
347
+ vector.mask %mask {
348
+ vector.transfer_write %arg1 , %arg0 [%c0 , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <1 x4 xf16 >, memref <1 x1 x4 xf16 >
349
+ } : vector <1 x4 xi1 >
350
+ return
351
+ }
352
+
353
+ // -----
354
+
289
355
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
290
356
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
291
357
func.func @cast_away_nontrivial_map_masked_transfer_write (%arg0: memref <1 x4 x8 xf16 >, %arg1: vector <1 x1 x4 xf16 >, %arg2: vector <1 x4 x1 xi1 >) {
0 commit comments