@@ -166,6 +166,28 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
166
166
167
167
// -----
168
168
169
+ // CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
170
+ // CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
171
+ // CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
172
+
173
+ // CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
174
+ // CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
175
+ // CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
176
+ // CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
177
+ // CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
178
+ // CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
179
+ // CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
180
+ // CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
181
+ // CHECK: return %[[VAL_6]] : vector<1x8xi32>
182
+ // CHECK: }
183
+ func.func @cast_away_contraction_leading_one_dims_vec_mat (%lhs: vector <1 x1 x8 xi32 >,
184
+ %rhs: vector <1 x8 x8 xi32 >,
185
+ %acc: vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
186
+ %result = vector.contract {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>], iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ], kind = #vector.kind <add >} %lhs , %rhs , %acc : vector <1 x1 x8 xi32 >, vector <1 x8 x8 xi32 > into vector <1 x8 xi32 >
187
+ return %result : vector <1 x8 xi32 >
188
+ }
189
+
190
+ // -----
169
191
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
170
192
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
171
193
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
0 commit comments