@@ -468,16 +468,67 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
468
468
469
469
// -----
470
470
471
- // CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
472
- func.func @fold_dynamic_subview_with_memref_load_store_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index ) -> f32 {
471
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
472
+ // CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
473
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> f32
474
+ func.func @fold_dynamic_subview_with_memref_load_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index ) -> f32 {
473
475
%c0 = arith.constant 0 : index
474
476
%expand_shape = memref.expand_shape %arg0 [[0 , 1 ], [2 , 3 ]] : memref <16 x?xf32 , strided <[16 , 1 ]>> into memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
475
477
%0 = memref.load %expand_shape [%c0 , %arg1 , %arg2 , %c0 ] : memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
476
478
return %0 : f32
477
479
}
478
- // CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
479
- // CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
480
- // CHECK: return %[[LOAD]]
480
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
481
+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
482
+ // CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
483
+ // CHECK-NEXT: return %[[VAL1]] : f32
484
+
485
+ // -----
486
+
487
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
488
+ // CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
489
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
490
+ func.func @fold_dynamic_subview_with_memref_store_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index ) {
491
+ %c0 = arith.constant 0 : index
492
+ %c1f32 = arith.constant 1.0 : f32
493
+ %expand_shape = memref.expand_shape %arg0 [[0 , 1 ], [2 , 3 ]] : memref <16 x?xf32 , strided <[16 , 1 ]>> into memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
494
+ memref.store %c1f32 , %expand_shape [%c0 , %arg1 , %arg2 , %c0 ] : memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
495
+ return
496
+ }
497
+ // CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
498
+ // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
499
+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
500
+ // CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
501
+ // CHECK-NEXT: return
502
+
503
+ // -----
504
+
505
+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
506
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
507
+ // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
508
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
509
+ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim (%alloc: memref <2048 x16 xf32 >, %c10: index , %c5: index , %c0: index ) {
510
+ %subview = memref.subview %alloc [%c5 , 0 ] [%c10 , 16 ] [1 , 1 ] : memref <2048 x16 xf32 > to memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>>
511
+ %expand_shape = memref.expand_shape %subview [[0 ], [1 , 2 , 3 ]] : memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>> into memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
512
+ %dim = memref.dim %expand_shape , %c0 : memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
513
+
514
+ affine.for %arg6 = 0 to %dim step 64 {
515
+ affine.for %arg7 = 0 to 16 step 16 {
516
+ %dummy_load = affine.load %expand_shape [%arg6 , 0 , %arg7 , %arg7 ] : memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
517
+ affine.store %dummy_load , %subview [%arg6 , %arg7 ] : memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>>
518
+ }
519
+ }
520
+ return
521
+ }
522
+ // CHECK-NEXT: memref.subview
523
+ // CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
524
+ // CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
525
+ // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
526
+ // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
527
+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
528
+ // CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
529
+ // CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
530
+ // CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
531
+ // CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
481
532
482
533
// -----
483
534
0 commit comments