@@ -868,14 +868,16 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
868
868
// -----
869
869
870
870
// CHECK-LABEL: fold_vector_transfer_masks
871
- func.func @fold_vector_transfer_masks (%A: memref <?x?xf32 >) -> (vector <4 x8 xf32 >) {
871
+ func.func @fold_vector_transfer_masks (%A: memref <?x?xf32 >) -> (vector <4 x8 xf32 >, vector < 4 x[ 4 ]x f32 > ) {
872
872
// CHECK: %[[C0:.+]] = arith.constant 0 : index
873
873
%c0 = arith.constant 0 : index
874
874
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
875
875
%f0 = arith.constant 0.0 : f32
876
876
877
877
%mask = vector.constant_mask [8 , 4 ] : vector <8 x4 xi1 >
878
878
879
+ %arith_all_true_mask = arith.constant dense <true > : vector <4 x[4 ]xi1 >
880
+
879
881
// CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
880
882
%1 = vector.transfer_read %A [%c0 , %c0 ], %f0 , %mask
881
883
{permutation_map = affine_map <(d0 , d1 ) -> (d1 , d0 )>} : memref <?x?xf32 >, vector <4 x8 xf32 >
@@ -884,8 +886,14 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
884
886
vector.transfer_write %1 , %A [%c0 , %c0 ], %mask
885
887
{permutation_map = affine_map <(d0 , d1 ) -> (d1 , d0 )>} : vector <4 x8 xf32 >, memref <?x?xf32 >
886
888
889
+ // CHECK: vector.transfer_read %{{.*}}, %[[F0]] :
890
+ %2 = vector.transfer_read %A [%c0 , %c0 ], %f0 , %arith_all_true_mask : memref <?x?xf32 >, vector <4 x[4 ]xf32 >
891
+
892
+ // CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] :
893
+ vector.transfer_write %2 , %A [%c0 , %c0 ], %arith_all_true_mask : vector <4 x[4 ]xf32 >, memref <?x?xf32 >
894
+
887
895
// CHECK: return
888
- return %1 : vector <4 x8 xf32 >
896
+ return %1 , %2 : vector <4 x8 xf32 >, vector < 4 x[ 4 ]x f32 >
889
897
}
890
898
891
899
// -----
0 commit comments