@@ -363,17 +363,40 @@ func.func @xfer_read_minor_identity_transposed_masked_scalable(
363
363
}
364
364
365
365
///----------------------------------------------------------------------------------------
366
- /// vector.transfer_read
366
+ /// [Pattern: TransferOpReduceRank]
367
+ ///
368
+ /// IN: vector.transfer_read (minor identity map + broadcast)
369
+ /// OUT: vector.transfer_read + vector.broadcast
367
370
///----------------------------------------------------------------------------------------
368
371
/// TODO: Review and categorize
369
372
373
+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims
374
+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x4x2x3xf32> {
375
+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
376
+ // CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
377
+ // CHECK: return %[[BC]] : vector<8x4x2x3xf32>
378
+ func.func @xfer_read_minor_identitiy_bcast_dims (
379
+ %mem: memref <?x?x?x?xf32 >,
380
+ %idx: index ) -> vector <8 x4 x2 x3 xf32 > {
381
+
382
+ %pad = arith.constant 0.000000e+00 : f32
383
+
384
+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
385
+ in_bounds = [true , true , true , true ],
386
+ permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
387
+ } : memref <?x?x?x?xf32 >, vector <8 x4 x2 x3 xf32 >
388
+
389
+ return %res : vector <8 x4 x2 x3 xf32 >
390
+ }
391
+
370
392
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
371
393
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
372
394
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
373
395
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
374
396
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
375
397
func.func @xfer_read_minor_identitiy_bcast_dims_scalable (
376
- %mem: memref <?x?x?x?xf32 >, %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
398
+ %mem: memref <?x?x?x?xf32 >,
399
+ %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
377
400
378
401
%pad = arith.constant 0.000000e+00 : f32
379
402
@@ -385,31 +408,57 @@ func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
385
408
return %res : vector <8 x[4 ]x2 x3 xf32 >
386
409
}
387
410
411
+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_with_mask
412
+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>
413
+ // CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1>
414
+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32>
415
+ // CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
416
+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
417
+ // CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
418
+ // CHECK: return %[[BC]] : vector<8x4x2x3xf32>
419
+ func.func @xfer_read_minor_identitiy_bcast_dims_with_mask (
420
+ %mem: memref <?x?x?x?xf32 >,
421
+ %mask: vector <4 x3 xi1 >,
422
+ %idx: index ) -> vector <8 x4 x2 x3 xf32 > {
423
+
424
+ %pad = arith.constant 0.000000e+00 : f32
425
+
426
+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad , %mask {
427
+ in_bounds = [true , true , true , true ],
428
+ permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
429
+ } : memref <?x?x?x?xf32 >, vector <8 x4 x2 x3 xf32 >
430
+
431
+ return %res : vector <8 x4 x2 x3 xf32 >
432
+ }
433
+
388
434
// The vector.mask-wrapped form is not supported: the transfer_read is left unchanged and no vector.broadcast is created
389
435
390
436
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_masked
391
437
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
392
- // CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1 >
393
- // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32 > {
438
+ // CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1 >
439
+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32 > {
394
440
// CHECK-NOT: vector.broadcast
395
- // CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32 > } : vector<[4]x3xi1 > -> vector<8x[4]x2x3xf32 >
441
+ // CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32 > } : vector<4x3xi1 > -> vector<8x4x2x3xf32 >
396
442
func.func @xfer_read_minor_identitiy_bcast_dims_masked (
397
443
%mem: memref <?x?x?x?xf32 >,
398
- %mask: vector <[ 4 ]x 3 x i1 >,
399
- %idx: index ) -> vector <8 x[ 4 ]x 2 x 3 x f32 > {
444
+ %mask: vector <4 x 3 x i1 >,
445
+ %idx: index ) -> vector <8 x 4 x 2 x 3 x f32 > {
400
446
401
447
%pad = arith.constant 0.000000e+00 : f32
402
448
403
449
%res = vector.mask %mask {
404
450
vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
405
451
in_bounds = [true , true , true , true ],
406
452
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
407
- } : memref <?x?x?x?xf32 >, vector <8 x[ 4 ]x 2 x 3 x f32 >
408
- } : vector <[ 4 ]x 3 x i1 > -> vector <8 x[ 4 ]x 2 x 3 x f32 >
453
+ } : memref <?x?x?x?xf32 >, vector <8 x 4 x 2 x 3 x f32 >
454
+ } : vector <4 x 3 x i1 > -> vector <8 x 4 x 2 x 3 x f32 >
409
455
410
- return %res : vector <8 x[ 4 ]x 2 x 3 x f32 >
456
+ return %res : vector <8 x 4 x 2 x 3 x f32 >
411
457
}
412
458
459
+ ///----------------------------------------------------------------------------------------
460
+ /// Transform Dialect (TD) sequence
461
+ ///----------------------------------------------------------------------------------------
413
462
module attributes {transform.with_named_sequence } {
414
463
transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
415
464
%f = transform.structured.match ops {[" func.func" ]} in %module_op
0 commit comments