@@ -74,6 +74,17 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8
74
74
return %0 : vector <1 x1 x8 x8 xf32 >
75
75
}
76
76
77
+ /// A transpose with a scalable dim should not be unrolled.
78
+
79
+ // CHECK-LABEL: func @transpose23_scalable
80
+ // CHECK-NOT: vector.extract
81
+ // CHECK-NOT: vector.insert
82
+ // CHECK: vector.transpose
83
+ func.func @transpose23_scalable (%arg0: vector <2 x[3 ]xf32 >) -> vector <[3 ]x2 xf32 > {  // 2-D transpose whose source has a scalable trailing dim ([3]).
84
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <2 x[3 ]xf32 > to vector <[3 ]x2 xf32 >  // Expected to remain a vector.transpose, with no extract/insert ops (see the directives above this function).
85
+ return %0 : vector <[3 ]x2 xf32 >
86
+ }
87
+
77
88
module attributes {transform.with_named_sequence } {
78
89
transform.named_sequence @__transform_main (%func_op: !transform.op <" func.func" > {transform.readonly }) {
79
90
transform.apply_patterns to %func_op {
@@ -778,3 +789,63 @@ module attributes {transform.with_named_sequence} {
778
789
transform.yield
779
790
}
780
791
}
792
+
793
+ // -----
794
+
795
+ /// A transpose of a rank-2 vector with a leading or trailing unit dim is lowered to shape_cast.
796
+
797
+ // CHECK-LABEL: func @transpose10_4x1xf32
798
+ func.func @transpose10_4x1xf32 (%arg0: vector <4 x1 xf32 >) -> vector <1 x4 xf32 > {  // Fixed-size 4x1 -> 1x4: the trailing unit dim moves to the front.
799
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
800
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <4 x1 xf32 > to vector <1 x4 xf32 >  // Input op; the directive above expects a shape_cast in its place.
801
+ return %0 : vector <1 x4 xf32 >
802
+ }
803
+
804
+ // CHECK-LABEL: func @transpose10_nx4x1xf32
805
+ func.func @transpose10_nx4x1xf32 (%arg0: vector <[4 ]x1 xf32 >) -> vector <1 x[4 ]xf32 > {  // Scalable [4] dim with a fixed trailing unit dim: [4]x1 -> 1x[4].
806
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
807
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <[4 ]x1 xf32 > to vector <1 x[4 ]xf32 >  // Input op; the directive above expects a shape_cast in its place.
808
+ return %0 : vector <1 x[4 ]xf32 >
809
+ }
810
+
811
+ // CHECK-LABEL: func @transpose10_1x4xf32
812
+ func.func @transpose10_1x4xf32 (%arg0: vector <1 x4 xf32 >) -> vector <4 x1 xf32 > {  // Fixed-size 1x4 -> 4x1: the leading unit dim moves to the back.
813
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
814
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <1 x4 xf32 > to vector <4 x1 xf32 >  // Input op; the directive above expects a shape_cast in its place.
815
+ return %0 : vector <4 x1 xf32 >
816
+ }
817
+
818
+ // CHECK-LABEL: func @transpose10_1xnx4xf32
819
+ func.func @transpose10_1xnx4xf32 (%arg0: vector <1 x[4 ]xf32 >) -> vector <[4 ]x1 xf32 > {  // Fixed leading unit dim with a scalable [4] dim: 1x[4] -> [4]x1.
820
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
821
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <1 x[4 ]xf32 > to vector <[4 ]x1 xf32 >  // Input op; the directive above expects a shape_cast in its place.
822
+ return %0 : vector <[4 ]x1 xf32 >
823
+ }
824
+
825
+ /// A transpose involving a scalable unit dim ([1]) should not be lowered to shape_cast.
826
+
827
+ // CHECK-LABEL: func @transpose10_4xnx1xf32
828
+ func.func @transpose10_4xnx1xf32 (%arg0: vector <4 x[1 ]xf32 >) -> vector <[1 ]x4 xf32 > {  // Fixed 4 dim with a scalable trailing [1] dim: 4x[1] -> [1]x4.
829
+ // CHECK-NOT: vector.shape_cast
830
+ // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
831
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <4 x[1 ]xf32 > to vector <[1 ]x4 xf32 >  // Input op; the directives above expect it to stay a vector.transpose.
832
+ return %0 : vector <[1 ]x4 xf32 >
833
+ }
834
+
835
+ // CHECK-LABEL: func @transpose10_nx4xnx1xf32
836
+ func.func @transpose10_nx4xnx1xf32(%arg0: vector<[4]x[1]xf32>) -> vector<[1]x[4]xf32> {  // Both dims scalable, unit-sized dim trailing: [4]x[1] -> [1]x[4].
837
+ // CHECK-NOT: vector.shape_cast
838
+ // CHECK: vector.transpose %{{.*}} : vector<[4]x[1]xf32> to vector<[1]x[4]xf32>
839
+ %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[1]xf32> to vector<[1]x[4]xf32>  // A scalable [1] dim holds vscale x 1 elements, so it is not a true unit dim; the op must stay a vector.transpose.
840
+
841
+ return %0 : vector<[1]x[4]xf32>
842
+ }
843
+
844
+ module attributes {transform.with_named_sequence } {  // Transform script: applies the vector lower_transpose patterns to %func_op.
845
+ transform.named_sequence @__transform_main (%func_op: !transform.op <" func.func" > {transform.readonly }) {
846
+ transform.apply_patterns to %func_op {  // The pattern set listed in this region is applied to the %func_op handle.
847
+ transform.apply_patterns.vector.lower_transpose
848
+ } : !transform.op <" func.func" >
849
+ transform.yield
850
+ }
851
+ }
0 commit comments