Commit 57e4360

[mlir][memref] Add memref alias folders for expand/collapse_shape for vector load/store (#95223)
This patch adds patterns to fold memref aliases for expand_shape/collapse_shape feeding into vector.load/vector.store and vector.maskedload/vector.maskedstore.
1 parent 3e3b7c7 commit 57e4360
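
For illustration, a minimal before/after sketch of the new fold, distilled from the tests added below (the %src, %i, and %view names are mine, not from the patch): a vector.load through a memref.expand_shape is rewritten to load straight from the source memref with linearized indices.

  // Before: the load goes through the reshaped view.
  %c0 = arith.constant 0 : index
  %view = memref.expand_shape %src [[0, 1]] output_shape [4, 8]
      : memref<32xf32> into memref<4x8xf32>
  %v = vector.load %view[%i, %c0] : memref<4x8xf32>, vector<8xf32>

  // After: the view is bypassed and the leading index is linearized to
  // %i * 8 (column 0), matching the affine_map<()[s0] -> (s0 * 8)> the
  // tests below check for.
  %idx = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%i]
  %v2 = vector.load %src[%idx] : memref<32xf32>, vector<8xf32>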

2 files changed: +239 −23 lines
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 79 additions & 11 deletions
@@ -518,10 +518,25 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -551,10 +566,25 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, collapseShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -651,10 +681,25 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
-                                                  expandShapeOp.getViewSource(),
-                                                  sourceIndices);
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices);
+      })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
       })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
@@ -685,11 +730,26 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
-            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
             sourceIndices);
       })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -763,12 +823,20 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
                LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
                LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
                SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
       patterns.getContext());
 }
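
These folders are exercised the same way as the existing subview folders. A minimal sketch of a standalone check, assuming the fold-memref-alias-ops mlir-opt pass (which wraps populateFoldMemRefAliasOpPatterns) as the test file below uses:

// RUN: mlir-opt -fold-memref-alias-ops %s | FileCheck %s
func.func @exercise_fold(%src : memref<32xf32>, %i : index) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %view = memref.expand_shape %src [[0, 1]] output_shape [4, 8]
      : memref<32xf32> into memref<4x8xf32>
  // CHECK: vector.load %{{.*}} : memref<32xf32>, vector<8xf32>
  %v = vector.load %view[%i, %c0] : memref<4x8xf32>, vector<8xf32>
  return %v : vector<8xf32>
}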

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 160 additions & 12 deletions
@@ -473,10 +473,10 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
@@ -487,11 +487,11 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
   %c0 = arith.constant 0 : index
   %c1f32 = arith.constant 1.0 : f32
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return
 }
 // CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return
 
 // -----
@@ -819,29 +819,29 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 
 // -----
 
-func.func @fold_vector_load(
+func.func @fold_vector_load_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
   return %1 : vector<12x32xf32>
 }
 
-// CHECK: func @fold_vector_load
+// CHECK: func @fold_vector_load_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>
 
 // -----
 
-func.func @fold_vector_maskedload(
+func.func @fold_vector_maskedload_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
   return %1 : vector<32xf32>
 }
 
-// CHECK: func @fold_vector_maskedload
+// CHECK: func @fold_vector_maskedload_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -851,14 +851,14 @@ func.func @fold_vector_maskedload(
 
 // -----
 
-func.func @fold_vector_store(
+func.func @fold_vector_store_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
   return
 }
 
-// CHECK: func @fold_vector_store
+// CHECK: func @fold_vector_store_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -868,18 +868,166 @@ func.func @fold_vector_store(
 
 // -----
 
-func.func @fold_vector_maskedstore(
+func.func @fold_vector_maskedstore_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
   return
 }
 
-// CHECK: func @fold_vector_maskedstore
+// CHECK: func @fold_vector_maskedstore_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
 // CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
 // CHECK: return
+
+// -----
+
+func.func @fold_vector_load_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.load %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_load_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedload_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.store %val, %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_store_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_load_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.load %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_load_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedload_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.store %val, %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_store_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
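
The collapse_shape direction is the inverse mapping: a flat index into memref<32xf32> is delinearized into a (row, col) pair for memref<4x8xf32> by the floordiv/mod maps checked above. For example, flat index 19 maps to (19 floordiv 8, 19 mod 8) = (2, 3), and indeed 2 * 8 + 3 = 19. A hand-written equivalent of the folded form (illustrative only; %src and %i are my names, not from the patch):

func.func @collapse_fold_by_hand(%src : memref<4x8xf32>, %i : index) -> vector<8xf32> {
  // Delinearize the flat index exactly as the folder's affine maps do.
  %row = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%i]
  %col = affine.apply affine_map<()[s0] -> (s0 mod 8)>()[%i]
  %v = vector.load %src[%row, %col] : memref<4x8xf32>, vector<8xf32>
  return %v : vector<8xf32>
}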
