@@ -719,7 +719,7 @@ func.func @fold_extract_transpose(
719
719
// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
720
720
// CHECK-SAME: %[[A:.*]]: f32
721
721
// CHECK: return %[[A]] : f32
722
- func.func @fold_extract_broadcast_same_input_output_scalar (%a : f32 ,
722
+ func.func @fold_extract_broadcast_same_input_output_scalar (%a : f32 ,
723
723
%idx0 : index , %idx1 : index , %idx2 : index ) -> f32 {
724
724
%b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
725
725
%r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
@@ -731,7 +731,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
731
731
// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
732
732
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
733
733
// CHECK: return %[[A]] : vector<4xf32>
734
- func.func @fold_extract_broadcast_same_input_output_vec (%a : vector <4 xf32 >,
734
+ func.func @fold_extract_broadcast_same_input_output_vec (%a : vector <4 xf32 >,
735
735
%idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
736
736
%b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
737
737
%r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
@@ -744,7 +744,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
744
744
// CHECK-SAME: %[[A:.*]]: vector<f32>
745
745
// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
746
746
// CHECK: return %[[B]] : f32
747
- func.func @fold_extract_broadcast_0dvec_input_scalar_output (%a : vector <f32 >,
747
+ func.func @fold_extract_broadcast_0dvec_input_scalar_output (%a : vector <f32 >,
748
748
%idx0 : index , %idx1 : index , %idx2: index ) -> f32 {
749
749
%b = vector.broadcast %a : vector <f32 > to vector <1 x2 x4 xf32 >
750
750
%r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
@@ -780,7 +780,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
780
780
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
781
781
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
782
782
// CHECK: return %[[R]] : f32
783
- func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <2 x1 xf32 >,
783
+ func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <2 x1 xf32 >,
784
784
%idx : index , %idx1 : index , %idx2 : index ) -> f32 {
785
785
%b = vector.broadcast %a : vector <2 x1 xf32 > to vector <1 x2 x4 xf32 >
786
786
%r = vector.extract %b [%idx , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
@@ -795,7 +795,7 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
795
795
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
796
796
// CHECK: return %[[B]] : vector<4xf32>
797
797
// rank(extract_output) < rank(broadcast_input)
798
- func.func @fold_extract_broadcast_to_lower_rank (%a : vector <2 x4 xf32 >,
798
+ func.func @fold_extract_broadcast_to_lower_rank (%a : vector <2 x4 xf32 >,
799
799
%idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
800
800
%b = vector.broadcast %a : vector <2 x4 xf32 > to vector <1 x2 x4 xf32 >
801
801
%r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
@@ -808,7 +808,7 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
808
808
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
809
809
// CHECK: return %[[B]] : vector<4xf32>
810
810
// rank(extract_output) > rank(broadcast_input)
811
- func.func @fold_extract_broadcast_to_higher_rank (%a : f32 , %idx0 : index , %idx1 : index )
811
+ func.func @fold_extract_broadcast_to_higher_rank (%a : f32 , %idx0 : index , %idx1 : index )
812
812
-> vector <4 xf32 > {
813
813
%b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
814
814
%r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
@@ -822,7 +822,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
822
822
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
823
823
// CHECK: return %[[R]] : vector<8xf32>
824
824
// rank(extract_output) == rank(broadcast_input)
825
- func.func @fold_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >, %idx0 : index )
825
+ func.func @fold_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >, %idx0 : index )
826
826
-> vector <8 xf32 > {
827
827
%b = vector.broadcast %a : vector <1 xf32 > to vector <1 x8 xf32 >
828
828
%r = vector.extract %b [%idx0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
@@ -1169,7 +1169,7 @@ func.func @broadcast_poison() -> vector<4x6xi8> {
1169
1169
return %broadcast : vector <4 x6 xi8 >
1170
1170
}
1171
1171
1172
- // -----
1172
+ // -----
1173
1173
1174
1174
// CHECK-LABEL: broadcast_splat_constant
1175
1175
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
@@ -2756,6 +2756,19 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
2756
2756
2757
2757
// -----
2758
2758
2759
+ // CHECK-LABEL: func @empty_vector_mask_with_passthru
2760
+ // CHECK-SAME: %[[IN:.*]]: vector<8xf32>, %[[MASK:.*]]: vector<8xi1>, %[[PASSTHRU:.*]]: vector<8xf32>
2761
+ func.func @empty_vector_mask_with_passthru (%a : vector <8 xf32 >, %mask : vector <8 xi1 >,
2762
+ %passthru : vector <8 xf32 >) -> vector <8 xf32 > {
2763
+ // CHECK-NOT: vector.mask
2764
+ // CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[IN]], %[[PASSTHRU]] : vector<8xi1>, vector<8xf32>
2765
+ // CHECK: return %[[SEL]] : vector<8xf32>
2766
+ %0 = vector.mask %mask , %passthru { vector.yield %a : vector <8 xf32 > } : vector <8 xi1 > -> vector <8 xf32 >
2767
+ return %0 : vector <8 xf32 >
2768
+ }
2769
+
2770
+ // -----
2771
+
2759
2772
// CHECK-LABEL: func @all_true_vector_mask
2760
2773
// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
2761
2774
func.func @all_true_vector_mask (%ta : tensor <3 x4 xf32 >) -> vector <3 x4 xf32 > {
0 commit comments