@@ -363,17 +363,39 @@ func.func @xfer_read_minor_identity_transposed_masked_scalable(
363
363
}
364
364
365
365
///----------------------------------------------------------------------------------------
366
- /// vector.transfer_read
366
+ /// [Pattern: TransferOpReduceRank]
367
+ ///
368
+ /// IN: vector.transfer_read (minor identity map + broadcast)
369
+ /// OUT: vector.transfer_read + vector.broadcast
367
370
///----------------------------------------------------------------------------------------
368
- /// TODO: Review and categorize
371
+
372
+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims
373
+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x4x2x3xf32> {
374
+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
375
+ // CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
376
+ // CHECK: return %[[BC]] : vector<8x4x2x3xf32>
377
+ func.func @xfer_read_minor_identitiy_bcast_dims (
378
+ %mem: memref <?x?x?x?xf32 >,
379
+ %idx: index ) -> vector <8 x4 x2 x3 xf32 > {
380
+
381
+ %pad = arith.constant 0.000000e+00 : f32
382
+
383
+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
384
+ in_bounds = [true , true , true , true ],
385
+ permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
386
+ } : memref <?x?x?x?xf32 >, vector <8 x4 x2 x3 xf32 >
387
+
388
+ return %res : vector <8 x4 x2 x3 xf32 >
389
+ }
369
390
370
391
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
371
392
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
372
393
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
373
394
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
374
395
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
375
396
func.func @xfer_read_minor_identitiy_bcast_dims_scalable (
376
- %mem: memref <?x?x?x?xf32 >, %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
397
+ %mem: memref <?x?x?x?xf32 >,
398
+ %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
377
399
378
400
%pad = arith.constant 0.000000e+00 : f32
379
401
@@ -385,31 +407,80 @@ func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
385
407
return %res : vector <8 x[4 ]x2 x3 xf32 >
386
408
}
387
409
410
+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_with_mask
411
+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>
412
+ // CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1>
413
+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32>
414
+ // CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
415
+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
416
+ // CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
417
+ // CHECK: return %[[BC]] : vector<8x4x2x3xf32>
418
+ func.func @xfer_read_minor_identitiy_bcast_dims_with_mask (
419
+ %mem: memref <?x?x?x?xf32 >,
420
+ %mask: vector <4 x3 xi1 >,
421
+ %idx: index ) -> vector <8 x4 x2 x3 xf32 > {
422
+
423
+ %pad = arith.constant 0.000000e+00 : f32
424
+
425
+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad , %mask {
426
+ in_bounds = [true , true , true , true ],
427
+ permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
428
+ } : memref <?x?x?x?xf32 >, vector <8 x4 x2 x3 xf32 >
429
+
430
+ return %res : vector <8 x4 x2 x3 xf32 >
431
+ }
432
+
433
+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_with_mask_scalable
434
+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>
435
+ // CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1>
436
+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32>
437
+ // CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
438
+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
439
+ // CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
440
+ // CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
441
+ func.func @xfer_read_minor_identitiy_bcast_dims_with_mask_scalable (
442
+ %mem: memref <?x?x?x?xf32 >,
443
+ %mask: vector <[4 ]x3 xi1 >,
444
+ %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
445
+
446
+ %pad = arith.constant 0.000000e+00 : f32
447
+
448
+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad , %mask {
449
+ in_bounds = [true , true , true , true ],
450
+ permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
451
+ } : memref <?x?x?x?xf32 >, vector <8 x[4 ]x2 x3 xf32 >
452
+
453
+ return %res : vector <8 x[4 ]x2 x3 xf32 >
454
+ }
455
+
388
456
// Masked version is not supported
389
457
390
458
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_masked
391
459
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
392
- // CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1 >
393
- // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32 > {
460
+ // CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1 >
461
+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32 > {
394
462
// CHECK-NOT: vector.broadcast
395
- // CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32 > } : vector<[4]x3xi1 > -> vector<8x[4]x2x3xf32 >
463
+ // CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32 > } : vector<4x3xi1 > -> vector<8x4x2x3xf32 >
396
464
func.func @xfer_read_minor_identitiy_bcast_dims_masked (
397
465
%mem: memref <?x?x?x?xf32 >,
398
- %mask: vector <[ 4 ]x 3 x i1 >,
399
- %idx: index ) -> vector <8 x[ 4 ]x 2 x 3 x f32 > {
466
+ %mask: vector <4 x 3 x i1 >,
467
+ %idx: index ) -> vector <8 x 4 x 2 x 3 x f32 > {
400
468
401
469
%pad = arith.constant 0.000000e+00 : f32
402
470
403
471
%res = vector.mask %mask {
404
472
vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
405
473
in_bounds = [true , true , true , true ],
406
474
permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
407
- } : memref <?x?x?x?xf32 >, vector <8 x[ 4 ]x 2 x 3 x f32 >
408
- } : vector <[ 4 ]x 3 x i1 > -> vector <8 x[ 4 ]x 2 x 3 x f32 >
475
+ } : memref <?x?x?x?xf32 >, vector <8 x 4 x 2 x 3 x f32 >
476
+ } : vector <4 x 3 x i1 > -> vector <8 x 4 x 2 x 3 x f32 >
409
477
410
- return %res : vector <8 x[ 4 ]x 2 x 3 x f32 >
478
+ return %res : vector <8 x 4 x 2 x 3 x f32 >
411
479
}
412
480
481
+ ///----------------------------------------------------------------------------------------
482
+ // TD sequence
483
+ ///----------------------------------------------------------------------------------------
413
484
module attributes {transform.with_named_sequence } {
414
485
transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
415
486
%f = transform.structured.match ops {[" func.func" ]} in %module_op
0 commit comments