@@ -313,6 +313,16 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
313
313
return %0 : vector <3 x2 xf32 >
314
314
}
315
315
316
+ #matvec_accesses_1 = [
317
+ affine_map <(m , k ) -> (m , k )>,
318
+ affine_map <(m , k ) -> (k )>,
319
+ affine_map <(m , k ) -> (m )>
320
+ ]
321
+ #matvec_trait_1 = {
322
+ indexing_maps = #matvec_accesses_1 ,
323
+ iterator_types = [" parallel" , " reduction" ]
324
+ }
325
+
316
326
// CHECK-LABEL: @masked_matvec_mk_k_m
317
327
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
318
328
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -323,17 +333,38 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
323
333
// CHECK: vector.transpose %[[MAT]]
324
334
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
325
335
%res = vector.mask %mask {
326
- vector.contract {
327
- indexing_maps = [affine_map <(m , k ) -> (m , k )>,
328
- affine_map <(m , k ) -> (k )>,
329
- affine_map <(m , k ) -> (m )>],
330
- iterator_types = [" parallel" , " reduction" ],
331
- kind = #vector.kind <add >
332
- } %arg0 , %arg1 , %arg2 : vector <4 x2 xf32 >, vector <2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
336
+ vector.contract #matvec_trait_1 %arg0 , %arg1 , %arg2
337
+ : vector <4 x2 xf32 >, vector <2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
333
338
} : vector <4 x2 xi1 > -> vector <4 xf32 >
334
339
return %res : vector <4 xf32 >
335
340
}
336
341
342
+ // CHECK-LABEL: @masked_matvec_mk_k_m_scalable_parallel_dim
343
+ // CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
344
+ // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
345
+ // CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
346
+ // CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
347
+ func.func @masked_matvec_mk_k_m_scalable_parallel_dim (%arg0: vector <[4 ]x2 xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <[4 ]xf32 >, %mask: vector <[4 ]x2 xi1 >) -> vector <[4 ]xf32 > {
348
+ // CHECK: vector.transpose %[[MASK]]
349
+ // CHECK: vector.transpose %[[MAT]]
350
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
351
+ %res = vector.mask %mask {
352
+ vector.contract #matvec_trait_1 %arg0 , %arg1 , %arg2
353
+ : vector <[4 ]x2 xf32 >, vector <2 xf32 >, vector <[4 ]xf32 > into vector <[4 ]xf32 >
354
+ } : vector <[4 ]x2 xi1 > -> vector <[4 ]xf32 >
355
+ return %res : vector <[4 ]xf32 >
356
+ }
357
+
358
+ #matvec_accesses_2 = [
359
+ affine_map <(m , k ) -> (k , m )>,
360
+ affine_map <(m , k ) -> (k )>,
361
+ affine_map <(m , k ) -> (m )>
362
+ ]
363
+ #matvec_trait_2 = {
364
+ indexing_maps = #matvec_accesses_2 ,
365
+ iterator_types = [" parallel" , " reduction" ]
366
+ }
367
+
337
368
// CHECK-LABEL: @masked_matvec_km_k_m
338
369
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
339
370
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -344,17 +375,38 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
344
375
// CHECK-NOT: vector.transpose %[[MAT]]
345
376
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
346
377
%res = vector.mask %mask {
347
- vector.contract {
348
- indexing_maps = [affine_map <(m , k ) -> (k , m )>,
349
- affine_map <(m , k ) -> (k )>,
350
- affine_map <(m , k ) -> (m )>],
351
- iterator_types = [" parallel" , " reduction" ],
352
- kind = #vector.kind <add >
353
- } %arg0 , %arg1 , %arg2 : vector <2 x4 xf32 >, vector <2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
378
+ vector.contract #matvec_trait_2 %arg0 , %arg1 , %arg2
379
+ : vector <2 x4 xf32 >, vector <2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
354
380
} : vector <4 x2 xi1 > -> vector <4 xf32 >
355
381
return %res : vector <4 xf32 >
356
382
}
357
383
384
+ // CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
385
+ // CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
386
+ // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
387
+ // CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
388
+ // CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
389
+ func.func @masked_matvec_km_k_m_scalable_parallel_dim (%arg0: vector <2 x[4 ]xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <[4 ]xf32 >, %mask: vector <[4 ]x2 xi1 >) -> vector <[4 ]xf32 > {
390
+ // CHECK: vector.transpose %[[MASK]]
391
+ // CHECK-NOT: vector.transpose %[[MAT]]
392
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
393
+ %res = vector.mask %mask {
394
+ vector.contract #matvec_trait_2 %arg0 , %arg1 , %arg2
395
+ : vector <2 x[4 ]xf32 >, vector <2 xf32 >, vector <[4 ]xf32 > into vector <[4 ]xf32 >
396
+ } : vector <[4 ]x2 xi1 > -> vector <[4 ]xf32 >
397
+ return %res : vector <[4 ]xf32 >
398
+ }
399
+
400
+ #matvec_accesses_3 = [
401
+ affine_map <(m , k ) -> (k )>,
402
+ affine_map <(m , k ) -> (m , k )>,
403
+ affine_map <(m , k ) -> (m )>
404
+ ]
405
+ #matvec_trait_3 = {
406
+ indexing_maps = #matvec_accesses_3 ,
407
+ iterator_types = [" parallel" , " reduction" ]
408
+ }
409
+
358
410
// CHECK-LABEL: @masked_matvec_k_mk_m
359
411
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
360
412
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -365,17 +417,54 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
365
417
// CHECK: vector.transpose %[[MAT]]
366
418
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
367
419
%res = vector.mask %mask {
368
- vector.contract {
369
- indexing_maps = [affine_map <(m , k ) -> (k )>,
370
- affine_map <(m , k ) -> (m , k )>,
371
- affine_map <(m , k ) -> (m )>],
372
- iterator_types = [" parallel" , " reduction" ],
373
- kind = #vector.kind <add >
374
- } %arg1 , %arg0 , %arg2 : vector <2 xf32 >, vector <4 x2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
420
+ vector.contract #matvec_trait_3 %arg1 , %arg0 , %arg2
421
+ : vector <2 xf32 >, vector <4 x2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
375
422
} : vector <4 x2 xi1 > -> vector <4 xf32 >
376
423
return %res : vector <4 xf32 >
377
424
}
378
425
426
+ // CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
427
+ // CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
428
+ // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
429
+ // CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
430
+ // CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
431
+ func.func @masked_matvec_k_mk_m_scalable_parallel_dim (%arg0: vector <[4 ]x2 xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <[4 ]xf32 >, %mask: vector <[4 ]x2 xi1 >) -> vector <[4 ]xf32 > {
432
+ // CHECK: vector.transpose %[[MASK]]
433
+ // CHECK: vector.transpose %[[MAT]]
434
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
435
+ %res = vector.mask %mask {
436
+ vector.contract #matvec_trait_3 %arg1 , %arg0 , %arg2
437
+ : vector <2 xf32 >, vector <[4 ]x2 xf32 >, vector <[4 ]xf32 > into vector <[4 ]xf32 >
438
+ } : vector <[4 ]x2 xi1 > -> vector <[4 ]xf32 >
439
+ return %res : vector <[4 ]xf32 >
440
+ }
441
+
442
+ #matvec_accesses_4 = [
443
+ affine_map <(m , k ) -> (k )>,
444
+ affine_map <(m , k ) -> (k , m )>,
445
+ affine_map <(m , k ) -> (m )>
446
+ ]
447
+ #matvec_trait_4 = {
448
+ indexing_maps = #matvec_accesses_4 ,
449
+ iterator_types = [" parallel" , " reduction" ]
450
+ }
451
+
452
+ // CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
453
+ // CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
454
+ // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
455
+ // CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
456
+ // CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
457
+ func.func @masked_matvec_k_km_m_scalable_parallel_dim (%arg0: vector <2 x[4 ]xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <[4 ]xf32 >, %mask: vector <[4 ]x2 xi1 >) -> vector <[4 ]xf32 > {
458
+ // CHECK: vector.transpose %[[MASK]]
459
+ // CHECK-NOT: vector.transpose %[[MAT]]
460
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
461
+ %res = vector.mask %mask {
462
+ vector.contract #matvec_trait_4 %arg1 , %arg0 , %arg2
463
+ : vector <2 xf32 >, vector <2 x[4 ]xf32 >, vector <[4 ]xf32 > into vector <[4 ]xf32 >
464
+ } : vector <[4 ]x2 xi1 > -> vector <[4 ]xf32 >
465
+ return %res : vector <[4 ]xf32 >
466
+ }
467
+
379
468
// CHECK-LABEL: @masked_matvec_k_km_m
380
469
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
381
470
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -386,17 +475,22 @@ func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
386
475
// CHECK-NOT: vector.transpose %[[MAT]]
387
476
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
388
477
%res = vector.mask %mask {
389
- vector.contract {
390
- indexing_maps = [affine_map <(m , k ) -> (k )>,
391
- affine_map <(m , k ) -> (k , m )>,
392
- affine_map <(m , k ) -> (m )>],
393
- iterator_types = [" parallel" , " reduction" ],
394
- kind = #vector.kind <add >
395
- } %arg1 , %arg0 , %arg2 : vector <2 xf32 >, vector <2 x4 xf32 >, vector <4 xf32 > into vector <4 xf32 >
478
+ vector.contract #matvec_trait_4 %arg1 , %arg0 , %arg2
479
+ : vector <2 xf32 >, vector <2 x4 xf32 >, vector <4 xf32 > into vector <4 xf32 >
396
480
} : vector <4 x2 xi1 > -> vector <4 xf32 >
397
481
return %res : vector <4 xf32 >
398
482
}
399
483
484
+ #matvec_accesses_5 = [
485
+ affine_map <(k , m ) -> (m , k )>,
486
+ affine_map <(k , m ) -> (k )>,
487
+ affine_map <(k , m ) -> (m )>
488
+ ]
489
+ #matvec_trait_5 = {
490
+ indexing_maps = #matvec_accesses_5 ,
491
+ iterator_types = [" reduction" , " parallel" ]
492
+ }
493
+
400
494
// CHECK-LABEL: @masked_tmatvec_mk_k_m
401
495
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
402
496
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -407,17 +501,38 @@ func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
407
501
// CHECK-NOT: vector.transpose %[[MASK]]
408
502
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
409
503
%res = vector.mask %mask {
410
- vector.contract {
411
- indexing_maps = [affine_map <(k , m ) -> (m , k )>,
412
- affine_map <(k , m ) -> (k )>,
413
- affine_map <(k , m ) -> (m )>],
414
- iterator_types = [" reduction" , " parallel" ],
415
- kind = #vector.kind <add >
416
- } %arg0 , %arg1 , %arg2 : vector <4 x2 xf32 >, vector <2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
504
+ vector.contract #matvec_trait_5 %arg0 , %arg1 , %arg2
505
+ : vector <4 x2 xf32 >, vector <2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
417
506
} : vector <2 x4 xi1 > -> vector <4 xf32 >
418
507
return %res : vector <4 xf32 >
419
508
}
420
509
510
+ // CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
511
+ // CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
512
+ // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
513
+ // CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
514
+ // CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
515
+ func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim (%arg0: vector <[4 ]x2 xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <[4 ]xf32 >, %mask: vector <2 x[4 ]xi1 >) -> vector <[4 ]xf32 > {
516
+ // CHECK: vector.transpose %[[MAT]]
517
+ // CHECK-NOT: vector.transpose %[[MASK]]
518
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
519
+ %res = vector.mask %mask {
520
+ vector.contract #matvec_trait_5 %arg0 , %arg1 , %arg2
521
+ : vector <[4 ]x2 xf32 >, vector <2 xf32 >, vector <[4 ]xf32 > into vector <[4 ]xf32 >
522
+ } : vector <2 x[4 ]xi1 > -> vector <[4 ]xf32 >
523
+ return %res : vector <[4 ]xf32 >
524
+ }
525
+
526
+ #matvec_accesses_6 = [
527
+ affine_map <(k , m ) -> (k , m )>,
528
+ affine_map <(k , m ) -> (k )>,
529
+ affine_map <(k , m ) -> (m )>
530
+ ]
531
+ #matvec_trait_6 = {
532
+ indexing_maps = #matvec_accesses_6 ,
533
+ iterator_types = [" reduction" , " parallel" ]
534
+ }
535
+
421
536
// CHECK-LABEL: @masked_tmatvec_km_k_m
422
537
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
423
538
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -428,17 +543,38 @@ func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
428
543
// CHECK-NOT: vector.transpose %[[MASK]]
429
544
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
430
545
%res = vector.mask %mask {
431
- vector.contract {
432
- indexing_maps = [affine_map <(k , m ) -> (k , m )>,
433
- affine_map <(k , m ) -> (k )>,
434
- affine_map <(k , m ) -> (m )>],
435
- iterator_types = [" reduction" , " parallel" ],
436
- kind = #vector.kind <add >
437
- } %arg0 , %arg1 , %arg2 : vector <2 x4 xf32 >, vector <2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
546
+ vector.contract #matvec_trait_6 %arg0 , %arg1 , %arg2
547
+ : vector <2 x4 xf32 >, vector <2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
438
548
} : vector <2 x4 xi1 > -> vector <4 xf32 >
439
549
return %res : vector <4 xf32 >
440
550
}
441
551
552
+ // CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
553
+ // CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
554
+ // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
555
+ // CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
556
+ // CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
557
+ func.func @masked_tmatvec_km_k_m_scalable_parallel_dim (%arg0: vector <2 x[4 ]xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <[4 ]xf32 >, %mask: vector <2 x[4 ]xi1 >) -> vector <[4 ]xf32 > {
558
+ // CHECK-NOT: vector.transpose %[[MAT]]
559
+ // CHECK-NOT: vector.transpose %[[MASK]]
560
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
561
+ %res = vector.mask %mask {
562
+ vector.contract #matvec_trait_6 %arg0 , %arg1 , %arg2
563
+ : vector <2 x[4 ]xf32 >, vector <2 xf32 >, vector <[4 ]xf32 > into vector <[4 ]xf32 >
564
+ } : vector <2 x[4 ]xi1 > -> vector <[4 ]xf32 >
565
+ return %res : vector <[4 ]xf32 >
566
+ }
567
+
568
+ #matvec_accesses_7 = [
569
+ affine_map <(k , m ) -> (k )>,
570
+ affine_map <(k , m ) -> (m , k )>,
571
+ affine_map <(k , m ) -> (m )>
572
+ ]
573
+ #matvec_trait_7 = {
574
+ indexing_maps = #matvec_accesses_7 ,
575
+ iterator_types = [" reduction" , " parallel" ]
576
+ }
577
+
442
578
// CHECK-LABEL: @masked_tmatvec_k_mk_m
443
579
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
444
580
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -449,17 +585,38 @@ func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
449
585
// CHECK-NOT: vector.transpose %[[MASK]]
450
586
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
451
587
%res = vector.mask %mask {
452
- vector.contract {
453
- indexing_maps = [affine_map <(k , m ) -> (k )>,
454
- affine_map <(k , m ) -> (m , k )>,
455
- affine_map <(k , m ) -> (m )>],
456
- iterator_types = [" reduction" , " parallel" ],
457
- kind = #vector.kind <add >
458
- } %arg1 , %arg0 , %arg2 : vector <2 xf32 >, vector <4 x2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
588
+ vector.contract #matvec_trait_7 %arg1 , %arg0 , %arg2
589
+ : vector <2 xf32 >, vector <4 x2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
459
590
} : vector <2 x4 xi1 > -> vector <4 xf32 >
460
591
return %res : vector <4 xf32 >
461
592
}
462
593
594
+ // CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
595
+ // CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
596
+ // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
597
+ // CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
598
+ // CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
599
+ func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim (%arg0: vector <[4 ]x2 xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <[4 ]xf32 >, %mask: vector <2 x[4 ]xi1 >) -> vector <[4 ]xf32 > {
600
+ // CHECK: vector.transpose %[[MAT]]
601
+ // CHECK-NOT: vector.transpose %[[MASK]]
602
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
603
+ %res = vector.mask %mask {
604
+ vector.contract #matvec_trait_7 %arg1 , %arg0 , %arg2
605
+ : vector <2 xf32 >, vector <[4 ]x2 xf32 >, vector <[4 ]xf32 > into vector <[4 ]xf32 >
606
+ } : vector <2 x[4 ]xi1 > -> vector <[4 ]xf32 >
607
+ return %res : vector <[4 ]xf32 >
608
+ }
609
+
610
+ #matvec_accesses_8 = [
611
+ affine_map <(k , m ) -> (k )>,
612
+ affine_map <(k , m ) -> (k , m )>,
613
+ affine_map <(k , m ) -> (m )>
614
+ ]
615
+ #matvec_trait_8 = {
616
+ indexing_maps = #matvec_accesses_8 ,
617
+ iterator_types = [" reduction" , " parallel" ]
618
+ }
619
+
463
620
// CHECK-LABEL: @masked_tmatvec_k_km_m
464
621
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
465
622
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -470,17 +627,28 @@ func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
470
627
// CHECK-NOT: vector.transpose %[[MASK]]
471
628
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
472
629
%res = vector.mask %mask {
473
- vector.contract {
474
- indexing_maps = [affine_map <(k , m ) -> (k )>,
475
- affine_map <(k , m ) -> (k , m )>,
476
- affine_map <(k , m ) -> (m )>],
477
- iterator_types = [" reduction" , " parallel" ],
478
- kind = #vector.kind <add >
479
- } %arg1 , %arg0 , %arg2 : vector <2 xf32 >, vector <2 x4 xf32 >, vector <4 xf32 > into vector <4 xf32 >
630
+ vector.contract #matvec_trait_8 %arg1 , %arg0 , %arg2
631
+ : vector <2 xf32 >, vector <2 x4 xf32 >, vector <4 xf32 > into vector <4 xf32 >
480
632
} : vector <2 x4 xi1 > -> vector <4 xf32 >
481
633
return %res : vector <4 xf32 >
482
634
}
483
635
636
+ // CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
637
+ // CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
638
+ // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
639
+ // CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
640
+ // CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
641
+ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim (%arg0: vector <2 x[4 ]xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <[4 ]xf32 >, %mask: vector <2 x[4 ]xi1 >) -> vector <[4 ]xf32 > {
642
+ // CHECK-NOT: vector.transpose %[[MAT]]
643
+ // CHECK-NOT: vector.transpose %[[MASK]]
644
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
645
+ %res = vector.mask %mask {
646
+ vector.contract #matvec_trait_8 %arg1 , %arg0 , %arg2
647
+ : vector <2 xf32 >, vector <2 x[4 ]xf32 >, vector <[4 ]xf32 > into vector <[4 ]xf32 >
648
+ } : vector <2 x[4 ]xi1 > -> vector <[4 ]xf32 >
649
+ return %res : vector <[4 ]xf32 >
650
+ }
651
+
484
652
485
653
module attributes {transform.with_named_sequence } {
486
654
transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
0 commit comments