Skip to content

Commit c262dec

Browse files
committed
[mlir][vector] Update v.contract -> v.outerproduct tests
Update the tests for the vector.contract to vector.outerproduct conversion of _matvec_ operations by adding cases for scalable vectors. This patch touches a single test file, vector-contract-to-outerproduct-transforms.mlir; the remaining _matmul_ tests in that file will be updated in a separate patch. Only the parallel dimension is made scalable: making the reduction dimension scalable would produce different lowering patterns that do not use vector.outerproduct, so tests for that case would need to be added to a different file.
1 parent f24c443 commit c262dec

File tree

1 file changed

+224
-56
lines changed

1 file changed

+224
-56
lines changed

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir

Lines changed: 224 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,16 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
313313
return %0 : vector<3x2xf32>
314314
}
315315

316+
#matvec_accesses_1 = [
317+
affine_map<(m, k) -> (m, k)>,
318+
affine_map<(m, k) -> (k)>,
319+
affine_map<(m, k) -> (m)>
320+
]
321+
#matvec_trait_1 = {
322+
indexing_maps = #matvec_accesses_1,
323+
iterator_types = ["parallel", "reduction"]
324+
}
325+
316326
// CHECK-LABEL: @masked_matvec_mk_k_m
317327
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
318328
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -323,17 +333,38 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
323333
// CHECK: vector.transpose %[[MAT]]
324334
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
325335
%res = vector.mask %mask {
326-
vector.contract {
327-
indexing_maps = [affine_map<(m, k) -> (m, k)>,
328-
affine_map<(m, k) -> (k)>,
329-
affine_map<(m, k) -> (m)>],
330-
iterator_types = ["parallel", "reduction"],
331-
kind = #vector.kind<add>
332-
} %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
336+
vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
337+
: vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
333338
} : vector<4x2xi1> -> vector<4xf32>
334339
return %res : vector<4xf32>
335340
}
336341

342+
// CHECK-LABEL: @masked_matvec_mk_k_m_scalable_parallel_dim
343+
// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
344+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
345+
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
346+
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
347+
func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
348+
// CHECK: vector.transpose %[[MASK]]
349+
// CHECK: vector.transpose %[[MAT]]
350+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
351+
%res = vector.mask %mask {
352+
vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
353+
: vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
354+
} : vector<[4]x2xi1> -> vector<[4]xf32>
355+
return %res : vector<[4]xf32>
356+
}
357+
358+
#matvec_accesses_2 = [
359+
affine_map<(m, k) -> (k, m)>,
360+
affine_map<(m, k) -> (k)>,
361+
affine_map<(m, k) -> (m)>
362+
]
363+
#matvec_trait_2 = {
364+
indexing_maps = #matvec_accesses_2,
365+
iterator_types = ["parallel", "reduction"]
366+
}
367+
337368
// CHECK-LABEL: @masked_matvec_km_k_m
338369
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
339370
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -344,17 +375,38 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
344375
// CHECK-NOT: vector.transpose %[[MAT]]
345376
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
346377
%res = vector.mask %mask {
347-
vector.contract {
348-
indexing_maps = [affine_map<(m, k) -> (k, m)>,
349-
affine_map<(m, k) -> (k)>,
350-
affine_map<(m, k) -> (m)>],
351-
iterator_types = ["parallel", "reduction"],
352-
kind = #vector.kind<add>
353-
} %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
378+
vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
379+
: vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
354380
} : vector<4x2xi1> -> vector<4xf32>
355381
return %res : vector<4xf32>
356382
}
357383

384+
// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
385+
// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
386+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
387+
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
388+
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
389+
func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
390+
// CHECK: vector.transpose %[[MASK]]
391+
// CHECK-NOT: vector.transpose %[[MAT]]
392+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
393+
%res = vector.mask %mask {
394+
vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
395+
: vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
396+
} : vector<[4]x2xi1> -> vector<[4]xf32>
397+
return %res : vector<[4]xf32>
398+
}
399+
400+
#matvec_accesses_3 = [
401+
affine_map<(m, k) -> (k)>,
402+
affine_map<(m, k) -> (m, k)>,
403+
affine_map<(m, k) -> (m)>
404+
]
405+
#matvec_trait_3 = {
406+
indexing_maps = #matvec_accesses_3,
407+
iterator_types = ["parallel", "reduction"]
408+
}
409+
358410
// CHECK-LABEL: @masked_matvec_k_mk_m
359411
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
360412
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -365,17 +417,54 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
365417
// CHECK: vector.transpose %[[MAT]]
366418
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
367419
%res = vector.mask %mask {
368-
vector.contract {
369-
indexing_maps = [affine_map<(m, k) -> (k)>,
370-
affine_map<(m, k) -> (m, k)>,
371-
affine_map<(m, k) -> (m)>],
372-
iterator_types = ["parallel", "reduction"],
373-
kind = #vector.kind<add>
374-
} %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
420+
vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
421+
: vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
375422
} : vector<4x2xi1> -> vector<4xf32>
376423
return %res : vector<4xf32>
377424
}
378425

426+
// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
427+
// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
428+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
429+
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
430+
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
431+
func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
432+
// CHECK: vector.transpose %[[MASK]]
433+
// CHECK: vector.transpose %[[MAT]]
434+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
435+
%res = vector.mask %mask {
436+
vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
437+
: vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
438+
} : vector<[4]x2xi1> -> vector<[4]xf32>
439+
return %res : vector<[4]xf32>
440+
}
441+
442+
#matvec_accesses_4 = [
443+
affine_map<(m, k) -> (k)>,
444+
affine_map<(m, k) -> (k, m)>,
445+
affine_map<(m, k) -> (m)>
446+
]
447+
#matvec_trait_4 = {
448+
indexing_maps = #matvec_accesses_4,
449+
iterator_types = ["parallel", "reduction"]
450+
}
451+
452+
// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
453+
// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
454+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
455+
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
456+
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
457+
func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
458+
// CHECK: vector.transpose %[[MASK]]
459+
// CHECK-NOT: vector.transpose %[[MAT]]
460+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
461+
%res = vector.mask %mask {
462+
vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
463+
: vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
464+
} : vector<[4]x2xi1> -> vector<[4]xf32>
465+
return %res : vector<[4]xf32>
466+
}
467+
379468
// CHECK-LABEL: @masked_matvec_k_km_m
380469
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
381470
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -386,17 +475,22 @@ func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
386475
// CHECK-NOT: vector.transpose %[[MAT]]
387476
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
388477
%res = vector.mask %mask {
389-
vector.contract {
390-
indexing_maps = [affine_map<(m, k) -> (k)>,
391-
affine_map<(m, k) -> (k, m)>,
392-
affine_map<(m, k) -> (m)>],
393-
iterator_types = ["parallel", "reduction"],
394-
kind = #vector.kind<add>
395-
} %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
478+
vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
479+
: vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
396480
} : vector<4x2xi1> -> vector<4xf32>
397481
return %res : vector<4xf32>
398482
}
399483

484+
#matvec_accesses_5 = [
485+
affine_map<(k, m) -> (m, k)>,
486+
affine_map<(k, m) -> (k)>,
487+
affine_map<(k, m) -> (m)>
488+
]
489+
#matvec_trait_5 = {
490+
indexing_maps = #matvec_accesses_5,
491+
iterator_types = ["reduction", "parallel"]
492+
}
493+
400494
// CHECK-LABEL: @masked_tmatvec_mk_k_m
401495
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
402496
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -407,17 +501,38 @@ func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
407501
// CHECK-NOT: vector.transpose %[[MASK]]
408502
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
409503
%res = vector.mask %mask {
410-
vector.contract {
411-
indexing_maps = [affine_map<(k, m) -> (m, k)>,
412-
affine_map<(k, m) -> (k)>,
413-
affine_map<(k, m) -> (m)>],
414-
iterator_types = ["reduction", "parallel"],
415-
kind = #vector.kind<add>
416-
} %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
504+
vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
505+
: vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
417506
} : vector<2x4xi1> -> vector<4xf32>
418507
return %res : vector<4xf32>
419508
}
420509

510+
// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
511+
// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
512+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
513+
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
514+
// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
515+
func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
516+
// CHECK: vector.transpose %[[MAT]]
517+
// CHECK-NOT: vector.transpose %[[MASK]]
518+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
519+
%res = vector.mask %mask {
520+
vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
521+
: vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
522+
} : vector<2x[4]xi1> -> vector<[4]xf32>
523+
return %res : vector<[4]xf32>
524+
}
525+
526+
#matvec_accesses_6 = [
527+
affine_map<(k, m) -> (k, m)>,
528+
affine_map<(k, m) -> (k)>,
529+
affine_map<(k, m) -> (m)>
530+
]
531+
#matvec_trait_6 = {
532+
indexing_maps = #matvec_accesses_6,
533+
iterator_types = ["reduction", "parallel"]
534+
}
535+
421536
// CHECK-LABEL: @masked_tmatvec_km_k_m
422537
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
423538
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -428,17 +543,38 @@ func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
428543
// CHECK-NOT: vector.transpose %[[MASK]]
429544
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
430545
%res = vector.mask %mask {
431-
vector.contract {
432-
indexing_maps = [affine_map<(k, m) -> (k, m)>,
433-
affine_map<(k, m) -> (k)>,
434-
affine_map<(k, m) -> (m)>],
435-
iterator_types = ["reduction", "parallel"],
436-
kind = #vector.kind<add>
437-
} %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
546+
vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
547+
: vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
438548
} : vector<2x4xi1> -> vector<4xf32>
439549
return %res : vector<4xf32>
440550
}
441551

552+
// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
553+
// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
554+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
555+
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
556+
// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
557+
func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
558+
// CHECK-NOT: vector.transpose %[[MAT]]
559+
// CHECK-NOT: vector.transpose %[[MASK]]
560+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
561+
%res = vector.mask %mask {
562+
vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
563+
: vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
564+
} : vector<2x[4]xi1> -> vector<[4]xf32>
565+
return %res : vector<[4]xf32>
566+
}
567+
568+
#matvec_accesses_7 = [
569+
affine_map<(k, m) -> (k)>,
570+
affine_map<(k, m) -> (m, k)>,
571+
affine_map<(k, m) -> (m)>
572+
]
573+
#matvec_trait_7 = {
574+
indexing_maps = #matvec_accesses_7,
575+
iterator_types = ["reduction", "parallel"]
576+
}
577+
442578
// CHECK-LABEL: @masked_tmatvec_k_mk_m
443579
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
444580
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -449,17 +585,38 @@ func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
449585
// CHECK-NOT: vector.transpose %[[MASK]]
450586
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
451587
%res = vector.mask %mask {
452-
vector.contract {
453-
indexing_maps = [affine_map<(k, m) -> (k)>,
454-
affine_map<(k, m) -> (m, k)>,
455-
affine_map<(k, m) -> (m)>],
456-
iterator_types = ["reduction", "parallel"],
457-
kind = #vector.kind<add>
458-
} %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
588+
vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
589+
: vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
459590
} : vector<2x4xi1> -> vector<4xf32>
460591
return %res : vector<4xf32>
461592
}
462593

594+
// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
595+
// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
596+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
597+
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
598+
// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
599+
func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
600+
// CHECK: vector.transpose %[[MAT]]
601+
// CHECK-NOT: vector.transpose %[[MASK]]
602+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
603+
%res = vector.mask %mask {
604+
vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
605+
: vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
606+
} : vector<2x[4]xi1> -> vector<[4]xf32>
607+
return %res : vector<[4]xf32>
608+
}
609+
610+
#matvec_accesses_8 = [
611+
affine_map<(k, m) -> (k)>,
612+
affine_map<(k, m) -> (k, m)>,
613+
affine_map<(k, m) -> (m)>
614+
]
615+
#matvec_trait_8 = {
616+
indexing_maps = #matvec_accesses_8,
617+
iterator_types = ["reduction", "parallel"]
618+
}
619+
463620
// CHECK-LABEL: @masked_tmatvec_k_km_m
464621
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
465622
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
@@ -470,17 +627,28 @@ func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
470627
// CHECK-NOT: vector.transpose %[[MASK]]
471628
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
472629
%res = vector.mask %mask {
473-
vector.contract {
474-
indexing_maps = [affine_map<(k, m) -> (k)>,
475-
affine_map<(k, m) -> (k, m)>,
476-
affine_map<(k, m) -> (m)>],
477-
iterator_types = ["reduction", "parallel"],
478-
kind = #vector.kind<add>
479-
} %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
630+
vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
631+
: vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
480632
} : vector<2x4xi1> -> vector<4xf32>
481633
return %res : vector<4xf32>
482634
}
483635

636+
// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
637+
// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
638+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
639+
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
640+
// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
641+
func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
642+
// CHECK-NOT: vector.transpose %[[MAT]]
643+
// CHECK-NOT: vector.transpose %[[MASK]]
644+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
645+
%res = vector.mask %mask {
646+
vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
647+
: vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
648+
} : vector<2x[4]xi1> -> vector<[4]xf32>
649+
return %res : vector<[4]xf32>
650+
}
651+
484652

485653
module attributes {transform.with_named_sequence} {
486654
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {

0 commit comments

Comments
 (0)