@@ -235,6 +235,23 @@ func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
235
235
// ============================================================================
236
236
// Matvec 2 (plain + masked + scalable)
237
237
// ============================================================================
238
+ // CHECK-LABEL: func @matvec_km_k_m
239
+ // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
240
+ // CHECK-SAME: %[[X:.*1]]: vector<2xf32>
241
+ // CHECK-SAME: %[[B:.*2]]: vector<2xf32>
242
+ // CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
243
+ // CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
244
+ // CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
245
+ // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
246
+ // CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
247
+ // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
248
+ func.func @matvec_km_k_m (%A: vector <2 x2 xf32 >,
249
+ %x: vector <2 xf32 >,
250
+ %b: vector <2 xf32 >) -> vector <2 xf32 > {
251
+ %0 = vector.contract #matvec_trait_2 %A , %x , %b : vector <2 x2 xf32 >, vector <2 xf32 > into vector <2 xf32 >
252
+ return %0 : vector <2 xf32 >
253
+ }
254
+
238
255
// CHECK-LABEL: @masked_matvec_km_k_m
239
256
// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
240
257
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
@@ -273,26 +290,27 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
273
290
return %res : vector <[4 ]xf32 >
274
291
}
275
292
276
- // CHECK-LABEL: func @matvec_km_k_m
293
+ // ============================================================================
294
+ // Matvec 3 (plain + masked + scalable)
295
+ // ============================================================================
296
+ // CHECK-LABEL: func @matvec_k_mk_m
277
297
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
278
298
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
279
299
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
280
- // CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
281
- // CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
282
- // CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
283
- // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
284
- // CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
285
- // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
286
- func.func @matvec_km_k_m (%A: vector <2 x2 xf32 >,
300
+ // CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
301
+ // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
302
+ // CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
303
+ // CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
304
+ // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
305
+ // CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
306
+ // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
307
+ func.func @matvec_k_mk_m (%A: vector <2 x2 xf32 >,
287
308
%x: vector <2 xf32 >,
288
309
%b: vector <2 xf32 >) -> vector <2 xf32 > {
289
- %0 = vector.contract #matvec_trait_2 %A , %x , %b : vector <2 x 2 x f32 >, vector <2 x f32 > into vector <2 xf32 >
310
+ %0 = vector.contract #matvec_trait_3 %x , %A , %b : vector <2 x f32 >, vector <2 x 2 x f32 > into vector <2 xf32 >
290
311
return %0 : vector <2 xf32 >
291
312
}
292
313
293
- // ============================================================================
294
- // Matvec 3 (plain + masked + scalable)
295
- // ============================================================================
296
314
// CHECK-LABEL: @masked_matvec_k_mk_m
297
315
// CHECK-SAME: %[[A:.+]]: vector<4x2xf32>
298
316
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
@@ -331,24 +349,6 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
331
349
return %res : vector <[4 ]xf32 >
332
350
}
333
351
334
- // CHECK-LABEL: func @matvec_k_mk_m
335
- // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
336
- // CHECK-SAME: %[[X:.*1]]: vector<2xf32>
337
- // CHECK-SAME: %[[B:.*2]]: vector<2xf32>
338
- // CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
339
- // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
340
- // CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
341
- // CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
342
- // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
343
- // CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
344
- // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
345
- func.func @matvec_k_mk_m (%A: vector <2 x2 xf32 >,
346
- %x: vector <2 xf32 >,
347
- %b: vector <2 xf32 >) -> vector <2 xf32 > {
348
- %0 = vector.contract #matvec_trait_3 %x , %A , %b : vector <2 xf32 >, vector <2 x2 xf32 > into vector <2 xf32 >
349
- return %0 : vector <2 xf32 >
350
- }
351
-
352
352
// ============================================================================
353
353
// Matvec 4 (plain + masked + scalable)
354
354
// ============================================================================
0 commit comments