@@ -406,17 +406,17 @@ func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %
406
406
iterator_types = [" parallel" , " parallel" , " reduction" ]
407
407
}
408
408
409
- #matvec_accesses_0 = [
409
+ #matvec_accesses_1 = [
410
410
affine_map <(m , k ) -> (m , k )>,
411
411
affine_map <(m , k ) -> (k )>,
412
412
affine_map <(m , k ) -> (m )>
413
413
]
414
- #matvec_trait_0 = {
415
- indexing_maps = #matvec_accesses_0 ,
414
+ #matvec_trait_1 = {
415
+ indexing_maps = #matvec_accesses_1 ,
416
416
iterator_types = [" parallel" , " reduction" ]
417
417
}
418
418
419
- // CHECK-LABEL: func.func @masked_extract_contract2 (
419
+ // CHECK-LABEL: func.func @masked_matvec_mk_k_m (
420
420
// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
421
421
// CHECK-SAME: %{{.*}}: vector<3xf32>,
422
422
// CHECK-SAME: %{{.*}}: vector<2xf32>,
@@ -431,17 +431,17 @@ func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %
431
431
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
432
432
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
433
433
434
- func.func @masked_extract_contract2 (%arg0: vector <2 x3 xf32 >,
435
- %arg1: vector <3 xf32 >,
436
- %arg2: vector <2 xf32 >,
437
- %m: vector <2 x3 xi1 >) -> vector <2 xf32 > {
438
- %0 = vector.mask %m { vector.contract #matvec_trait_0 %arg0 , %arg1 , %arg2
434
+ func.func @masked_matvec_mk_k_m (%arg0: vector <2 x3 xf32 >,
435
+ %arg1: vector <3 xf32 >,
436
+ %arg2: vector <2 xf32 >,
437
+ %m: vector <2 x3 xi1 >) -> vector <2 xf32 > {
438
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0 , %arg1 , %arg2
439
439
: vector <2 x3 xf32 >, vector <3 xf32 > into vector <2 xf32 > } : vector <2 x3 xi1 > -> vector <2 xf32 >
440
440
return %0 : vector <2 xf32 >
441
441
}
442
442
443
443
444
- // CHECK-LABEL: func.func @masked_extract_contract2_scalable_parallel_dim (
444
+ // CHECK-LABEL: func.func @masked_matvec_mk_k_m_scalable_parallel_dim (
445
445
// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
446
446
// CHECK-SAME: %{{.*}}: vector<3xf32>,
447
447
// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
@@ -455,57 +455,15 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
455
455
456
456
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
457
457
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
458
- func.func @masked_extract_contract2_scalable_parallel_dim (%arg0: vector <[2 ]x3 xf32 >,
459
- %arg1: vector <3 xf32 >,
460
- %arg2: vector <[2 ]xf32 >,
461
- %m: vector <[2 ]x3 xi1 >) -> vector <[2 ]xf32 > {
462
- %0 = vector.mask %m { vector.contract #matvec_trait_0 %arg0 , %arg1 , %arg2
458
+ func.func @masked_matvec_mk_k_m_scalable_parallel_dim (%arg0: vector <[2 ]x3 xf32 >,
459
+ %arg1: vector <3 xf32 >,
460
+ %arg2: vector <[2 ]xf32 >,
461
+ %m: vector <[2 ]x3 xi1 >) -> vector <[2 ]xf32 > {
462
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0 , %arg1 , %arg2
463
463
: vector <[2 ]x3 xf32 >, vector <3 xf32 > into vector <[2 ]xf32 > } : vector <[2 ]x3 xi1 > -> vector <[2 ]xf32 >
464
464
return %0 : vector <[2 ]xf32 >
465
465
}
466
466
467
- #matvec_accesses_1 = [
468
- affine_map <(m , k ) -> (m , k )>,
469
- affine_map <(m , k ) -> (k )>,
470
- affine_map <(m , k ) -> (m )>
471
- ]
472
- #matvec_trait_1 = {
473
- indexing_maps = #matvec_accesses_1 ,
474
- iterator_types = [" parallel" , " reduction" ]
475
- }
476
-
477
- // CHECK-LABEL: @masked_matvec_mk_k_m
478
- // CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
479
- // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
480
- // CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
481
- // CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
482
- func.func @masked_matvec_mk_k_m (%arg0: vector <4 x2 xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <4 xf32 >, %mask: vector <4 x2 xi1 >) -> vector <4 xf32 > {
483
- // CHECK: vector.transpose %[[MASK]]
484
- // CHECK: vector.transpose %[[MAT]]
485
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
486
- %res = vector.mask %mask {
487
- vector.contract #matvec_trait_1 %arg0 , %arg1 , %arg2
488
- : vector <4 x2 xf32 >, vector <2 xf32 >, vector <4 xf32 > into vector <4 xf32 >
489
- } : vector <4 x2 xi1 > -> vector <4 xf32 >
490
- return %res : vector <4 xf32 >
491
- }
492
-
493
- // CHECK-LABEL: @masked_matvec_mk_k_m_scalable_parallel_dim
494
- // CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
495
- // CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
496
- // CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
497
- // CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
498
- func.func @masked_matvec_mk_k_m_scalable_parallel_dim (%arg0: vector <[4 ]x2 xf32 >, %arg1: vector <2 xf32 >, %arg2: vector <[4 ]xf32 >, %mask: vector <[4 ]x2 xi1 >) -> vector <[4 ]xf32 > {
499
- // CHECK: vector.transpose %[[MASK]]
500
- // CHECK: vector.transpose %[[MAT]]
501
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
502
- %res = vector.mask %mask {
503
- vector.contract #matvec_trait_1 %arg0 , %arg1 , %arg2
504
- : vector <[4 ]x2 xf32 >, vector <2 xf32 >, vector <[4 ]xf32 > into vector <[4 ]xf32 >
505
- } : vector <[4 ]x2 xi1 > -> vector <[4 ]xf32 >
506
- return %res : vector <[4 ]xf32 >
507
- }
508
-
509
467
#matvec_accesses_2 = [
510
468
affine_map <(m , k ) -> (k , m )>,
511
469
affine_map <(m , k ) -> (k )>,
0 commit comments