@@ -67,6 +67,70 @@ func.func @create_mask_transpose_to_transposed_create_mask(
67
67
68
68
// -----
69
69
70
+ // CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
71
+ // CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
72
+ func.func @transposed_unit_dim_shape_cast_to_shape_cast (%vec: vector <[4 ]xf32 >) -> vector <1 x[4 ]xf32 > {
73
+ // CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
74
+ // CHECK-NOT: vector.transpose
75
+ %0 = vector.shape_cast %vec : vector <[4 ]xf32 > to vector <[4 ]x1 xf32 >
76
+ // 0 -> 1 is a unit dim:
77
+ %1 = vector.transpose %0 , [1 , 0 ] : vector <[4 ]x1 xf32 > to vector <1 x[4 ]xf32 >
78
+ return %1 : vector <1 x[4 ]xf32 >
79
+ }
80
+
81
+ // -----
82
+
83
+ // CHECK-LABEL: transposed_multiple_unit_dim_shape_cast_to_shape_cast
84
+ // CHECK-SAME: %[[VEC:.*]]: vector<120xf32>
85
+ func.func @transposed_multiple_unit_dim_shape_cast_to_shape_cast (%vec: vector <120 xf32 >) -> vector <2 x1 x3 x4 x1 x5 xf32 > {
86
+ // CHECK: vector.shape_cast %[[VEC]] : vector<120xf32> to vector<2x1x3x4x1x5xf32>
87
+ // CHECK-NOT: vector.transpose
88
+ %0 = vector.shape_cast %vec : vector <120 xf32 > to vector <1 x2 x3 x4 x5 x1 xf32 >
89
+ // 0 -> 1 and 4 -> 5 are unit dims:
90
+ %1 = vector.transpose %0 , [1 , 0 , 2 , 3 , 5 , 4 ] : vector <1 x2 x3 x4 x5 x1 xf32 > to vector <2 x1 x3 x4 x1 x5 xf32 >
91
+ return %1 : vector <2 x1 x3 x4 x1 x5 xf32 >
92
+ }
93
+
94
+ // -----
95
+
96
+ // CHECK-LABEL: transposed_non_unit_dim_shape_cast_0
97
+ // CHECK-SAME: %[[VEC:.*]]: vector<120xf32>
98
+ func.func @transposed_non_unit_dim_shape_cast_0 (%vec: vector <120 xf32 >) -> vector <1 x3 x2 x4 x1 x5 xf32 > {
99
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<120xf32> to vector<1x2x3x4x5x1xf32>
100
+ // CHECK-NEXT: vector.transpose %[[SHAPE_CAST]], [0, 2, 1, 3, 5, 4] : vector<1x2x3x4x5x1xf32> to vector<1x3x2x4x1x5xf32>
101
+ %0 = vector.shape_cast %vec : vector <120 xf32 > to vector <1 x2 x3 x4 x5 x1 xf32 >
102
+ // 1 -> 2 is a non-unit dim:
103
+ %1 = vector.transpose %0 , [0 , 2 , 1 , 3 , 5 , 4 ] : vector <1 x2 x3 x4 x5 x1 xf32 > to vector <1 x3 x2 x4 x1 x5 xf32 >
104
+ return %1 : vector <1 x3 x2 x4 x1 x5 xf32 >
105
+ }
106
+ // -----
107
+
108
+ // CHECK-LABEL: transposed_non_unit_dim_shape_cast_1
109
+ // CHECK-SAME: %[[VEC:.*]]: vector<1x256x256xf32>
110
+ func.func @transposed_non_unit_dim_shape_cast_1 (%vec: vector <1 x256 x256 xf32 >) -> vector <256 x256 xf32 > {
111
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x256x256xf32> to vector<256x256xf32>
112
+ // CHECK-NEXT: vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<256x256xf32> to vector<256x256xf32>
113
+ %0 = vector.shape_cast %vec : vector <1 x256 x256 xf32 > to vector <256 x256 xf32 >
114
+ // 0 -> 1 is a non-unit dim:
115
+ %1 = vector.transpose %0 , [1 , 0 ] : vector <256 x256 xf32 > to vector <256 x256 xf32 >
116
+ return %1 : vector <256 x256 xf32 >
117
+ }
118
+
119
+ // -----
120
+
121
+ // CHECK-LABEL: transposed_non_unit_dim_shape_cast_2
122
+ // CHECK-SAME: %[[VEC:.*]]: vector<20xf32>
123
+ func.func @transposed_non_unit_dim_shape_cast_2 (%vec: vector <20 xf32 >) -> vector <2 x5 x2 x1 xf32 > {
124
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<20xf32> to vector<2x1x2x5xf32>
125
+ // CHECK-NEXT: vector.transpose %[[SHAPE_CAST]], [0, 3, 2, 1] : vector<2x1x2x5xf32> to vector<2x5x2x1xf32>
126
+ %0 = vector.shape_cast %vec : vector <20 xf32 > to vector <2 x1 x2 x5 xf32 >
127
+ // 1 -> 3 transposes non-unit dims:
128
+ %1 = vector.transpose %0 , [0 , 3 , 2 , 1 ] : vector <2 x1 x2 x5 xf32 > to vector <2 x5 x2 x1 xf32 >
129
+ return %1 : vector <2 x5 x2 x1 xf32 >
130
+ }
131
+
132
+ // -----
133
+
70
134
// CHECK-LABEL: extract_from_create_mask
71
135
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
72
136
func.func @extract_from_create_mask (%dim0: index , %dim1: index ) -> vector <[4 ]x[4 ]xi1 > {
0 commit comments