@@ -585,35 +585,9 @@ func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
585
585
586
586
// -----
587
587
588
- // CHECK-LABEL: @vector_outerproduct_masked_f64
589
- // CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
590
- func.func @vector_outerproduct_masked_f64 (%lhs : vector <[2 ]xf64 >, %rhs : vector <[2 ]xf64 >, %acc : vector <[2 ]x[2 ]xf64 >, %dim0: index , %dim1: index ) {
591
- %mask = vector.create_mask %dim0 , %dim1 : vector <[2 ]x[2 ]xi1 >
592
- // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
593
- // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
594
- // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>
595
- %result = vector.mask %mask { vector.outerproduct %lhs , %rhs , %acc {kind = #vector.kind <add >} : vector <[2 ]xf64 >, vector <[2 ]xf64 > } : vector <[2 ]x[2 ]xi1 > -> vector <[2 ]x[2 ]xf64 >
596
- " prevent.dce" (%result ) : (vector <[2 ]x[2 ]xf64 >) -> ()
597
- }
598
-
599
- // -----
600
-
601
- // CHECK-LABEL: @vector_outerproduct_masked_f32
602
- // CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
603
- func.func @vector_outerproduct_masked_f32 (%lhs : vector <[4 ]xf32 >, %rhs : vector <[4 ]xf32 >, %acc : vector <[4 ]x[4 ]xf32 >, %dim0: index , %dim1: index ) {
604
- %mask = vector.create_mask %dim0 , %dim1 : vector <[4 ]x[4 ]xi1 >
605
- // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
606
- // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
607
- // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>
608
- %result = vector.mask %mask { vector.outerproduct %lhs , %rhs , %acc {kind = #vector.kind <add >} : vector <[4 ]xf32 >, vector <[4 ]xf32 > } : vector <[4 ]x[4 ]xi1 > -> vector <[4 ]x[4 ]xf32 >
609
- " prevent.dce" (%result ) : (vector <[4 ]x[4 ]xf32 >) -> ()
610
- }
611
-
612
- // -----
613
-
614
588
// CHECK-LABEL: @vector_outerproduct_masked_f16
615
589
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
616
- func.func @vector_outerproduct_masked_f16 (%lhs : vector <[8 ]xf16 >, %rhs : vector <[8 ]xf16 >, %acc : vector <[8 ]x[8 ]xf16 >, %dim0: index , %dim1: index ) {
590
+ func.func @vector_outerproduct_masked_f16 (%lhs : vector <[8 ]xf16 >, %rhs : vector <[8 ]xf16 >, %acc : vector <[8 ]x[8 ]xf16 >, %dim0 : index , %dim1 : index ) {
617
591
%mask = vector.create_mask %dim0 , %dim1 : vector <[8 ]x[8 ]xi1 >
618
592
// CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
619
593
// CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
@@ -626,7 +600,7 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
626
600
627
601
// CHECK-LABEL: @vector_outerproduct_masked_bf16
628
602
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
629
- func.func @vector_outerproduct_masked_bf16 (%lhs : vector <[8 ]xbf16 >, %rhs : vector <[8 ]xbf16 >, %acc : vector <[8 ]x[8 ]xbf16 >, %dim0: index , %dim1: index ) {
603
+ func.func @vector_outerproduct_masked_bf16 (%lhs : vector <[8 ]xbf16 >, %rhs : vector <[8 ]xbf16 >, %acc : vector <[8 ]x[8 ]xbf16 >, %dim0 : index , %dim1 : index ) {
630
604
%mask = vector.create_mask %dim0 , %dim1 : vector <[8 ]x[8 ]xi1 >
631
605
// CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
632
606
// CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
@@ -637,22 +611,28 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
637
611
638
612
// -----
639
613
640
- // CHECK-LABEL: @vector_outerproduct_f64
641
- // CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
642
- func.func @vector_outerproduct_f64 (%lhs : vector <[2 ]xf64 >, %rhs : vector <[2 ]xf64 >, %acc : vector <[2 ]x[2 ]xf64 >) {
643
- // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>
644
- %result = vector.outerproduct %lhs , %rhs , %acc {kind = #vector.kind <add >} : vector <[2 ]xf64 >, vector <[2 ]xf64 >
645
- " prevent.dce" (%result ) : (vector <[2 ]x[2 ]xf64 >) -> ()
614
+ // CHECK-LABEL: @vector_outerproduct_masked_f32
615
+ // CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
616
+ func.func @vector_outerproduct_masked_f32 (%lhs : vector <[4 ]xf32 >, %rhs : vector <[4 ]xf32 >, %acc : vector <[4 ]x[4 ]xf32 >, %dim0 : index , %dim1 : index ) {
617
+ %mask = vector.create_mask %dim0 , %dim1 : vector <[4 ]x[4 ]xi1 >
618
+ // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
619
+ // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
620
+ // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>
621
+ %result = vector.mask %mask { vector.outerproduct %lhs , %rhs , %acc {kind = #vector.kind <add >} : vector <[4 ]xf32 >, vector <[4 ]xf32 > } : vector <[4 ]x[4 ]xi1 > -> vector <[4 ]x[4 ]xf32 >
622
+ " prevent.dce" (%result ) : (vector <[4 ]x[4 ]xf32 >) -> ()
646
623
}
647
624
648
625
// -----
649
626
650
- // CHECK-LABEL: @vector_outerproduct_f32
651
- // CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
652
- func.func @vector_outerproduct_f32 (%lhs : vector <[4 ]xf32 >, %rhs : vector <[4 ]xf32 >, %acc : vector <[4 ]x[4 ]xf32 >) {
653
- // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>
654
- %result = vector.outerproduct %lhs , %rhs , %acc {kind = #vector.kind <add >} : vector <[4 ]xf32 >, vector <[4 ]xf32 >
655
- " prevent.dce" (%result ) : (vector <[4 ]x[4 ]xf32 >) -> ()
627
+ // CHECK-LABEL: @vector_outerproduct_masked_f64
628
+ // CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
629
+ func.func @vector_outerproduct_masked_f64 (%lhs : vector <[2 ]xf64 >, %rhs : vector <[2 ]xf64 >, %acc : vector <[2 ]x[2 ]xf64 >, %dim0 : index , %dim1 : index ) {
630
+ %mask = vector.create_mask %dim0 , %dim1 : vector <[2 ]x[2 ]xi1 >
631
+ // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
632
+ // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
633
+ // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>
634
+ %result = vector.mask %mask { vector.outerproduct %lhs , %rhs , %acc {kind = #vector.kind <add >} : vector <[2 ]xf64 >, vector <[2 ]xf64 > } : vector <[2 ]x[2 ]xi1 > -> vector <[2 ]x[2 ]xf64 >
635
+ " prevent.dce" (%result ) : (vector <[2 ]x[2 ]xf64 >) -> ()
656
636
}
657
637
658
638
// -----
@@ -674,3 +654,23 @@ func.func @vector_outerproduct_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xb
674
654
%result = vector.outerproduct %lhs , %rhs , %acc {kind = #vector.kind <add >} : vector <[8 ]xbf16 >, vector <[8 ]xbf16 >
675
655
" prevent.dce" (%result ) : (vector <[8 ]x[8 ]xbf16 >) -> ()
676
656
}
657
+
658
+ // -----
659
+
660
+ // CHECK-LABEL: @vector_outerproduct_f32
661
+ // CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
662
+ func.func @vector_outerproduct_f32 (%lhs : vector <[4 ]xf32 >, %rhs : vector <[4 ]xf32 >, %acc : vector <[4 ]x[4 ]xf32 >) {
663
+ // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>
664
+ %result = vector.outerproduct %lhs , %rhs , %acc {kind = #vector.kind <add >} : vector <[4 ]xf32 >, vector <[4 ]xf32 >
665
+ " prevent.dce" (%result ) : (vector <[4 ]x[4 ]xf32 >) -> ()
666
+ }
667
+
668
+ // -----
669
+
670
+ // CHECK-LABEL: @vector_outerproduct_f64
671
+ // CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
672
+ func.func @vector_outerproduct_f64 (%lhs : vector <[2 ]xf64 >, %rhs : vector <[2 ]xf64 >, %acc : vector <[2 ]x[2 ]xf64 >) {
673
+ // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>
674
+ %result = vector.outerproduct %lhs , %rhs , %acc {kind = #vector.kind <add >} : vector <[2 ]xf64 >, vector <[2 ]xf64 >
675
+ " prevent.dce" (%result ) : (vector <[2 ]x[2 ]xf64 >) -> ()
676
+ }
0 commit comments