@@ -710,24 +710,44 @@ func.func @fold_extract_transpose(
710
710
711
711
// -----
712
712
713
- // CHECK-LABEL: fold_extract_broadcast
713
+ // CHECK-LABEL: fold_extract_broadcast_same_type
714
714
// CHECK-SAME: %[[A:.*]]: f32
715
715
// CHECK: return %[[A]] : f32
716
- func.func @fold_extract_broadcast (%a : f32 ) -> f32 {
716
+ func.func @fold_extract_broadcast_same_type (%a : f32 ,
717
+ %idx0 : index ,
718
+ %idx1 : index ) -> f32 {
717
719
%b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
718
- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
720
+ // The indices don't batter for this folder, so we use mixed indices.
721
+ %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
719
722
return %r : f32
720
723
}
721
724
722
725
// -----
723
726
724
- // CHECK-LABEL: fold_extract_broadcast_0dvec
727
+ // CHECK-LABEL: fold_extract_broadcast_same_type_vec
728
+ // CHECK-SAME: %[[A:.*]]: vector<4xf32>
729
+ // CHECK: return %[[A]] : vector<4xf32>
730
+ func.func @fold_extract_broadcast_same_type_vec (%a : vector <4 xf32 >,
731
+ %idx0 : index )
732
+ -> vector <4 xf32 > {
733
+ %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
734
+ // The indices don't batter for this folder, so we use mixed indices.
735
+ %r = vector.extract %b [0 , %idx0 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
736
+ return %r : vector <4 xf32 >
737
+ }
738
+
739
+ // -----
740
+
741
+ // CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
725
742
// CHECK-SAME: %[[A:.*]]: vector<f32>
726
743
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
727
744
// CHECK: return %[[B]] : f32
728
- func.func @fold_extract_broadcast_0dvec (%a : vector <f32 >) -> f32 {
745
+ func.func @fold_extract_broadcast_0dvec_and_scalar (%a : vector <f32 >,
746
+ %idx0 : index ,
747
+ %idx1 : index ) -> f32 {
729
748
%b = vector.broadcast %a : vector <f32 > to vector <1 x2 x4 xf32 >
730
- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
749
+ // The indices don't batter for this folder, so we use mixed indices.
750
+ %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
731
751
return %r : f32
732
752
}
733
753
@@ -747,57 +767,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
747
767
// CHECK-LABEL: fold_extract_splat
748
768
// CHECK-SAME: %[[A:.*]]: f32
749
769
// CHECK: return %[[A]] : f32
750
- func.func @fold_extract_splat (%a : f32 ) -> f32 {
770
+ func.func @fold_extract_splat (%a : f32 , %idx0 : index , %idx1 : index ) -> f32 {
751
771
%b = vector.splat %a : vector <1 x2 x4 xf32 >
752
- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
772
+ // The indices don't batter for this folder, so we use mixed indices.
773
+ %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
753
774
return %r : f32
754
775
}
755
776
756
777
// -----
757
778
758
- // CHECK-LABEL: fold_extract_broadcast_vector
779
+ // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
759
780
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
760
- // CHECK: return %[[A]] : vector<4xf32>
761
- func.func @fold_extract_broadcast_vector (%a : vector <4 xf32 >) -> vector <4 xf32 > {
781
+ // CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
782
+ // CHECK: return %[[R]] : f32
783
+ func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <4 xf32 >) -> f32 {
762
784
%b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
763
- %r = vector.extract %b [0 , 1 ] : vector < 4 x f32 > from vector <1 x2 x4 xf32 >
764
- return %r : vector < 4 x f32 >
785
+ %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
786
+ return %r : f32
765
787
}
766
788
767
789
// -----
768
790
769
- // CHECK-LABEL: fold_extract_broadcast
791
+ // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
770
792
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
771
- // CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
793
+ // CHECK-SAME: %[[IDX:.*]]: index
794
+ // CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
795
+ // CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
772
796
// CHECK: return %[[R]] : f32
773
- func.func @fold_extract_broadcast (%a : vector <4 xf32 >) -> f32 {
797
+ // This folder is not yet implemented. Check that this does not fold.
798
+ func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi (
799
+ %a : vector <4 xf32 >,
800
+ %idx : index ) -> f32 {
774
801
%b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
775
- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
802
+ %r = vector.extract %b [%idx , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
776
803
return %r : f32
777
804
}
778
805
779
806
// -----
780
807
781
- // CHECK-LABEL: fold_extract_broadcast
808
+ // CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
782
809
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
783
810
// CHECK: return %[[B]] : vector<4xf32>
784
- func.func @fold_extract_broadcast (%a : f32 ) -> vector <4 xf32 > {
811
+ func.func @canonicalize_extract_broadcast_to_higher_rank (%a : f32 ,
812
+ %idx0 : index )
813
+ -> vector <4 xf32 > {
785
814
%b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
786
- %r = vector.extract %b [0 , 1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
815
+ // The indices don't batter for this canonicalizer, so we use mixed indices.
816
+ %r = vector.extract %b [0 , %idx0 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
787
817
return %r : vector <4 xf32 >
788
818
}
789
819
790
820
// -----
791
821
792
- // CHECK-LABEL: fold_extract_broadcast
822
+ // CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
793
823
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
794
824
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
795
825
// CHECK: return %[[R]] : vector<8xf32>
796
- func.func @fold_extract_broadcast (%a : vector <1 xf32 >) -> vector <8 xf32 > {
826
+ func.func @canonicalize_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >,
827
+ %idx0 : index )
828
+ -> vector <8 xf32 > {
797
829
%b = vector.broadcast %a : vector <1 xf32 > to vector <1 x8 xf32 >
798
- %r = vector.extract %b [0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
830
+ // The indices don't batter for this canonicalizer, so we use mixed indices.
831
+ %r = vector.extract %b [%idx0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
799
832
return %r : vector <8 xf32 >
800
833
}
834
+
801
835
// -----
802
836
803
837
// CHECK-LABEL: @fold_extract_shuffle
0 commit comments