@@ -468,23 +468,66 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
468
468
469
469
// -----
470
470
471
- // CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
472
- // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0 :.*]]: index)
473
- func.func @fold_dynamic_subview_with_memref_load_store_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index , %sz0: index ) -> f32 {
471
+ // CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
472
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3 :.*]]: index) -> f32
473
+ func.func @fold_dynamic_subview_with_memref_load_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index , %sz0: index ) -> f32 {
474
474
%c0 = arith.constant 0 : index
475
475
%expand_shape = memref.expand_shape %arg0 [[0 , 1 ], [2 , 3 ]] output_shape [1 , 16 , %sz0 , 1 ] : memref <16 x?xf32 , strided <[16 , 1 ]>> into memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
476
476
%0 = memref.load %expand_shape [%c0 , %arg1 , %arg2 , %c0 ] : memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
477
477
return %0 : f32
478
478
}
479
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
480
- // CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
481
- // CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
482
- // CHECK: return %[[VAL_0]] : f32
479
+ // CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
480
+ // CHECK-NEXT: return %[[VAL1]] : f32
483
481
484
482
// -----
485
483
486
- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
487
- // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
484
+ // CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
485
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
486
+ func.func @fold_dynamic_subview_with_memref_store_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index , %sz0 : index ) {
487
+ %c0 = arith.constant 0 : index
488
+ %c1f32 = arith.constant 1.0 : f32
489
+ %expand_shape = memref.expand_shape %arg0 [[0 , 1 ], [2 , 3 ]] output_shape [1 , 16 , %sz0 , 1 ] : memref <16 x?xf32 , strided <[16 , 1 ]>> into memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
490
+ memref.store %c1f32 , %expand_shape [%c0 , %arg1 , %arg2 , %c0 ] : memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
491
+ return
492
+ }
493
+ // CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
494
+ // CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
495
+ // CHECK-NEXT: return
496
+
497
+ // -----
498
+
499
+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
500
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
501
+ // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
502
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
503
+ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim (%alloc: memref <2048 x16 xf32 >, %c10: index , %c5: index , %c0: index , %sz0: index ) {
504
+ %subview = memref.subview %alloc [%c5 , 0 ] [%c10 , 16 ] [1 , 1 ] : memref <2048 x16 xf32 > to memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>>
505
+ %expand_shape = memref.expand_shape %subview [[0 ], [1 , 2 , 3 ]] output_shape [1 , 16 , %sz0 , 1 ] : memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>> into memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
506
+ %dim = memref.dim %expand_shape , %c0 : memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
507
+
508
+ affine.for %arg6 = 0 to %dim step 64 {
509
+ affine.for %arg7 = 0 to 16 step 16 {
510
+ %dummy_load = affine.load %expand_shape [%arg6 , 0 , %arg7 , %arg7 ] : memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
511
+ affine.store %dummy_load , %subview [%arg6 , %arg7 ] : memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>>
512
+ }
513
+ }
514
+ return
515
+ }
516
+ // CHECK-NEXT: memref.subview
517
+ // CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
518
+ // CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
519
+ // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
520
+ // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
521
+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
522
+ // CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
523
+ // CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
524
+ // CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
525
+ // CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
526
+
527
+ // -----
528
+
529
+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
530
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
488
531
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
489
532
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
490
533
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape (%arg0: memref <1024 x1024 xf32 >, %arg1: memref <1 xf32 >, %arg2: index ) -> f32 {
@@ -506,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
506
549
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
507
550
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
508
551
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
509
- // CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
510
- // CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
552
+ // CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[ %[[ARG3]], %[[ARG4]]]
553
+ // CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[ %[[ARG5]], %[[ARG6]]]
511
554
// CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
512
555
513
556
// -----
514
557
515
- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1 )>
516
- // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1 )>
558
+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024 )>
559
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 )>
517
560
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
518
561
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
519
562
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression (%arg0: memref <1024 x1024 xf32 >, %arg1: memref <1 xf32 >, %arg2: index ) -> f32 {
@@ -535,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
535
578
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
536
579
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
537
580
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
538
- // CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
539
- // CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
581
+ // CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
582
+ // CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[ %[[ARG5]], %[[ARG6]]]
540
583
// CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>
541
584
542
585
// -----
543
586
544
- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
545
- // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1 )>
587
+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
588
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 )>
546
589
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
547
590
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
548
591
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index (%arg0: memref <1024 x1024 xf32 >, %arg1: memref <1 xf32 >, %arg2: index ) -> f32 {
@@ -565,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
565
608
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
566
609
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
567
610
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
568
- // CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
569
- // CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
611
+ // CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[ %[[ARG3]]]
612
+ // CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[ %[[ARG5]], %[[ARG6]]]
570
613
// CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
571
614
572
615
// -----
0 commit comments