@@ -170,17 +170,18 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
170
170
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
171
171
// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
172
172
173
- // CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat (
174
- // CHECK-SAME: %[[VAL_0 :.*]]: vector<1x1x8xi32>,
175
- // CHECK-SAME: %[[VAL_1 :.*]]: vector<1x8x8xi32>,
176
- // CHECK-SAME: %[[VAL_2 :.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
177
- // CHECK: %[[VAL_3 :.*]] = vector.extract %[[VAL_0 ]][0] : vector<1x8xi32> from vector<1x1x8xi32>
178
- // CHECK: %[[VAL_4 :.*]] = vector.extract %[[VAL_2 ]][0] : vector<8xi32> from vector<1x8xi32>
179
- // CHECK: %[[VAL_5 :.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3 ]], %[[VAL_1 ]], %[[VAL_4 ]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
180
- // CHECK: %[[VAL_6 :.*]] = vector.broadcast %[[VAL_5 ]] : vector<8xi32> to vector<1x8xi32>
181
- // CHECK: return %[[VAL_6 ]] : vector<1x8xi32>
173
+ // CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims (
174
+ // CHECK-SAME: %[[LHS :.*]]: vector<1x1x8xi32>,
175
+ // CHECK-SAME: %[[RHS :.*]]: vector<1x8x8xi32>,
176
+ // CHECK-SAME: %[[ACC :.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
177
+ // CHECK: %[[EXT_LHS :.*]] = vector.extract %[[LHS ]][0] : vector<1x8xi32> from vector<1x1x8xi32>
178
+ // CHECK: %[[EXT_ACC :.*]] = vector.extract %[[ACC ]][0] : vector<8xi32> from vector<1x8xi32>
179
+ // CHECK: %[[RES :.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[EXT_LHS ]], %[[RHS ]], %[[EXT_ACC ]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
180
+ // CHECK: %[[BROADCAST_RES :.*]] = vector.broadcast %[[RES ]] : vector<8xi32> to vector<1x8xi32>
181
+ // CHECK: return %[[BROADCAST_RES ]] : vector<1x8xi32>
182
182
// CHECK: }
183
- func.func @cast_away_contraction_leading_one_dims_vec_mat (%lhs: vector <1 x1 x8 xi32 >,
183
+ // CHECK-NOT vector.transpose
184
+ func.func @cast_away_contraction_does_not_transpose_leading_unit_dims (%lhs: vector <1 x1 x8 xi32 >,
184
185
%rhs: vector <1 x8 x8 xi32 >,
185
186
%acc: vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
186
187
%result = vector.contract {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>], iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ], kind = #vector.kind <add >} %lhs , %rhs , %acc : vector <1 x1 x8 xi32 >, vector <1 x8 x8 xi32 > into vector <1 x8 xi32 >
0 commit comments