@@ -473,10 +473,10 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
  %c0 = arith.constant 0 : index
  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
  return %0 : f32
}
- // CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+ // CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return %[[VAL1]] : f32

// -----
@@ -487,11 +487,11 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
  %c0 = arith.constant 0 : index
  %c1f32 = arith.constant 1.0 : f32
  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-   memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+   memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
  return
}
// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
- // CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+ // CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return

// -----
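Note on the two hunks above: no affine.apply is needed after the fold, because the outer and inner indices are pinned to %c0, so each reassociation group linearizes to one of the original indices. A worked sketch of the arithmetic (editor's illustration, not part of the commit):

  group [0, 1] over expanded sizes [1, 16]:    %c0 * 16 + %arg1  ->  %arg1
  group [2, 3] over expanded sizes [%sz0, 1]:  %arg2 * 1 + %c0   ->  %arg2

The folded op therefore accesses %arg0[%arg1, %arg2] directly, and the only behavior under test here is that the {nontemporal = true} unit attribute survives the fold.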
@@ -819,29 +819,29 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind

// -----

- func.func @fold_vector_load(
+ func.func @fold_vector_load_subview(
    %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
  return %1 : vector<12x32xf32>
}

- // CHECK: func @fold_vector_load
+ // CHECK: func @fold_vector_load_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>

// -----

- func.func @fold_vector_maskedload(
+ func.func @fold_vector_maskedload_subview(
    %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
  return %1 : vector<32xf32>
}

- // CHECK: func @fold_vector_maskedload
+ // CHECK: func @fold_vector_maskedload_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -851,14 +851,14 @@ func.func @fold_vector_maskedload(

// -----

- func.func @fold_vector_store(
+ func.func @fold_vector_store_subview(
    %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
  return
}

- // CHECK: func @fold_vector_store
+ // CHECK: func @fold_vector_store_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -868,18 +868,166 @@ func.func @fold_vector_store(

// -----

- func.func @fold_vector_maskedstore(
+ func.func @fold_vector_maskedstore_subview(
    %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
  return
}

- // CHECK: func @fold_vector_maskedstore
+ // CHECK: func @fold_vector_maskedstore_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
// CHECK: return
+
+ // -----
+
+ func.func @fold_vector_load_expand_shape(
+     %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+   %c0 = arith.constant 0 : index
+   %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+   %1 = vector.load %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+   return %1 : vector<8xf32>
+ }
+
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+ // CHECK-LABEL: func @fold_vector_load_expand_shape
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+ // CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+ // CHECK: vector.load %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+ // -----
+
+ func.func @fold_vector_maskedload_expand_shape(
+     %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+   %c0 = arith.constant 0 : index
+   %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+   %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+   return %1 : vector<8xf32>
+ }
+
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+ // CHECK-LABEL: func @fold_vector_maskedload_expand_shape
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+ // CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+ // CHECK: vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+ // -----
+
+ func.func @fold_vector_store_expand_shape(
+     %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
+   %c0 = arith.constant 0 : index
+   %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+   vector.store %val, %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+   return
+ }
+
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+ // CHECK-LABEL: func @fold_vector_store_expand_shape
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+ // CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+ // CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+ // -----
+
+ func.func @fold_vector_maskedstore_expand_shape(
+     %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+   %c0 = arith.constant 0 : index
+   %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+   vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+   return
+ }
+
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+ // CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+ // CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+ // CHECK: vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+ // -----
+
+ func.func @fold_vector_load_collapse_shape(
+     %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
+   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+   %1 = vector.load %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+   return %1 : vector<8xf32>
+ }
+
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+ // CHECK-LABEL: func @fold_vector_load_collapse_shape
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+ // CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+ // CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+ // CHECK: vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+ // -----
+
+ func.func @fold_vector_maskedload_collapse_shape(
+     %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+   %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+   return %1 : vector<8xf32>
+ }
+
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+ // CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+ // CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+ // CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+ // CHECK: vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+
+ // -----
+
+ func.func @fold_vector_store_collapse_shape(
+     %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
+   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+   vector.store %val, %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+   return
+ }
+
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+ // CHECK-LABEL: func @fold_vector_store_collapse_shape
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+ // CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+ // CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+ // CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+ // -----
+
+ func.func @fold_vector_maskedstore_collapse_shape(
+     %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+   vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+   return
+ }
+
+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+ // CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+ // CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+ // CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+ // CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
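All eight new expand_shape/collapse_shape tests exercise the same index rewriting, which the CHECK-DAG maps summarize: folding a vector op through memref.expand_shape [[0, 1]] output_shape [4, 8] linearizes the index pair (%arg1, %c0) to %arg1 * 8, while folding through memref.collapse_shape [[0, 1]] delinearizes %arg1 to (%arg1 floordiv 8, %arg1 mod 8). A minimal before/after sketch for the load case, assuming the file's usual mlir-opt -fold-memref-alias-ops invocation (function names here are illustrative, not part of the commit):

// Before the fold: the vector.load goes through an expand_shape alias.
func.func @expand_shape_load_sketch(%base : memref<32xf32>, %i : index) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %e = memref.expand_shape %base [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
  %v = vector.load %e[%i, %c0] : memref<4x8xf32>, vector<8xf32>
  return %v : vector<8xf32>
}

// After the fold: the alias op is gone and the index is linearized by an affine.apply.
func.func @expand_shape_load_sketch_folded(%base : memref<32xf32>, %i : index) -> vector<8xf32> {
  %lin = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%i]
  %v = vector.load %base[%lin] : memref<32xf32>, vector<8xf32>
  return %v : vector<8xf32>
}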