Skip to content

Commit 7e175b3

Browse files
authored
[mlir][vector] Add more tests for ConvertVectorToLLVM (2/n) (#102203)
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass. Covers the following Ops: * vector.outerproduct
1 parent 50a2b31 commit 7e175b3

File tree

1 file changed

+180
-0
lines changed

1 file changed

+180
-0
lines changed

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,30 @@ func.func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x
600600
// CHECK: %[[T14:.*]] = builtin.unrealized_conversion_cast %[[T13]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
601601
// CHECK: return %[[T14]] : vector<2x3xf32>
602602

603+
func.func @outerproduct_scalable(%arg0: vector<2xf32>, %arg1: vector<[3]xf32>) -> vector<2x[3]xf32> {
604+
%2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<[3]xf32>
605+
return %2 : vector<2x[3]xf32>
606+
}
607+
// CHECK-LABEL: @outerproduct_scalable
608+
// CHECK-SAME: %[[A:.*]]: vector<2xf32>,
609+
// CHECK-SAME: %[[B:.*]]: vector<[3]xf32>)
610+
// CHECK: %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
611+
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2x[3]xf32> to !llvm.array<2 x vector<[3]xf32>>
612+
// CHECK: %[[T3:.*]] = llvm.mlir.constant(0 : i64) : i64
613+
// CHECK: %[[T4:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T3]] : i64] : vector<2xf32>
614+
// CHECK: %[[T5Insert:.*]] = llvm.insertelement %[[T4]]
615+
// CHECK: %[[T5:.*]] = llvm.shufflevector %[[T5Insert]]
616+
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<[3]xf32>
617+
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<2 x vector<[3]xf32>>
618+
// CHECK: %[[T9:.*]] = llvm.mlir.constant(1 : i64) : i64
619+
// CHECK: %[[T10:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T9]] : i64] : vector<2xf32>
620+
// CHECK: %[[T11Insert:.*]] = llvm.insertelement %[[T10]]
621+
// CHECK: %[[T11:.*]] = llvm.shufflevector %[[T11Insert]]
622+
// CHECK: %[[T12:.*]] = arith.mulf %[[T11]], %[[B]] : vector<[3]xf32>
623+
// CHECK: %[[T13:.*]] = llvm.insertvalue %[[T12]], %[[T8]][1] : !llvm.array<2 x vector<[3]xf32>>
624+
// CHECK: %[[T14:.*]] = builtin.unrealized_conversion_cast %[[T13]] : !llvm.array<2 x vector<[3]xf32>> to vector<2x[3]xf32>
625+
// CHECK: return %[[T14]] : vector<2x[3]xf32>
626+
603627
// -----
604628

605629
func.func @outerproduct_index(%arg0: vector<2xindex>, %arg1: vector<3xindex>) -> vector<2x3xindex> {
@@ -621,6 +645,25 @@ func.func @outerproduct_index(%arg0: vector<2xindex>, %arg1: vector<3xindex>) ->
621645
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T6]] : vector<3xindex> to vector<3xi64>
622646
// CHECK: %{{.*}} = llvm.insertvalue %[[T7]], %[[T8]][0] : !llvm.array<2 x vector<3xi64>>
623647

648+
func.func @outerproduct_index_scalable(%arg0: vector<2xindex>, %arg1: vector<[3]xindex>) -> vector<2x[3]xindex> {
649+
%2 = vector.outerproduct %arg0, %arg1 : vector<2xindex>, vector<[3]xindex>
650+
return %2 : vector<2x[3]xindex>
651+
}
652+
// CHECK-LABEL: @outerproduct_index_scalable
653+
// CHECK-SAME: %[[A:.*]]: vector<2xindex>,
654+
// CHECK-SAME: %[[B:.*]]: vector<[3]xindex>)
655+
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
656+
// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<2x[3]xindex>
657+
// CHECK: %[[T8:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<2x[3]xindex> to !llvm.array<2 x vector<[3]xi64>>
658+
// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : i64
659+
// CHECK: %[[T3:.*]] = llvm.extractelement %[[T1]]{{\[}}%[[T2]] : i64] : vector<2xi64>
660+
// CHECK: %[[T4:.*]] = llvm.insertelement %[[T3]]
661+
// CHECK: %[[T5:.*]] = llvm.shufflevector %[[T4]]
662+
// CHECK: %[[T5Cast:.*]] = builtin.unrealized_conversion_cast %[[T5]] : vector<[3]xi64> to vector<[3]xindex>
663+
// CHECK: %[[T6:.*]] = arith.muli %[[T5Cast]], %[[B]] : vector<[3]xindex>
664+
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T6]] : vector<[3]xindex> to vector<[3]xi64>
665+
// CHECK: %{{.*}} = llvm.insertvalue %[[T7]], %[[T8]][0] : !llvm.array<2 x vector<[3]xi64>>
666+
624667
// -----
625668

626669
func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
@@ -651,6 +694,34 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
651694
// CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
652695
// CHECK: return %[[T19]] : vector<2x3xf32>
653696

697+
func.func @outerproduct_add_scalable(%arg0: vector<2xf32>, %arg1: vector<[3]xf32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
698+
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<[3]xf32>
699+
return %2 : vector<2x[3]xf32>
700+
}
701+
// CHECK-LABEL: @outerproduct_add_scalable
702+
// CHECK-SAME: %[[A:.*]]: vector<2xf32>,
703+
// CHECK-SAME: %[[B:.*]]: vector<[3]xf32>,
704+
// CHECK-SAME: %[[C:.*]]: vector<2x[3]xf32>) -> vector<2x[3]xf32>
705+
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[C]] : vector<2x[3]xf32> to !llvm.array<2 x vector<[3]xf32>>
706+
// CHECK: %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
707+
// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[3]xf32> to !llvm.array<2 x vector<[3]xf32>>
708+
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
709+
// CHECK: %[[T5:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T4]] : i64] : vector<2xf32>
710+
// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
711+
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
712+
// CHECK: %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<[3]xf32>>
713+
// CHECK: %[[T9:.*]] = llvm.intr.fmuladd(%[[T6]], %[[B]], %[[T8]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
714+
// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<[3]xf32>>
715+
// CHECK: %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64
716+
// CHECK: %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32>
717+
// CHECK: %[[T14Insert:.*]] = llvm.insertelement %[[T13]]
718+
// CHECK: %[[T14:.*]] = llvm.shufflevector %[[T14Insert]]
719+
// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<[3]xf32>>
720+
// CHECK: %[[T17:.*]] = llvm.intr.fmuladd(%[[T14]], %[[B]], %[[T16]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
721+
// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<[3]xf32>>
722+
// CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<[3]xf32>> to vector<2x[3]xf32>
723+
// CHECK: return %[[T19]] : vector<2x[3]xf32>
724+
654725
// -----
655726

656727
func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
@@ -663,6 +734,16 @@ func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
663734
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
664735
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
665736

737+
func.func @masked_float_add_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f32, %arg2: vector<[2]xf32>, %m: vector<[2]xi1>) -> vector<[2]xf32> {
738+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
739+
return %0 : vector<[2]xf32>
740+
}
741+
742+
// CHECK-LABEL: func.func @masked_float_add_outerprod_scalable(
743+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> {
744+
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xf32>) -> vector<[2]xf32>
745+
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xf32>
746+
666747
// -----
667748

668749
func.func @masked_float_mul_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
@@ -676,6 +757,17 @@ func.func @masked_float_mul_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
676757
// CHECK: %[[VAL_9:.*]] = arith.mulf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
677758
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
678759

760+
func.func @masked_float_mul_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f32, %arg2: vector<[2]xf32>, %m: vector<[2]xi1>) -> vector<[2]xf32> {
761+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
762+
return %0 : vector<[2]xf32>
763+
}
764+
765+
// CHECK-LABEL: func.func @masked_float_mul_outerprod_scalable(
766+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> {
767+
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<[2]xf32>
768+
// CHECK: %[[VAL_9:.*]] = arith.mulf %[[VAL_8]], %[[VAL_2]] : vector<[2]xf32>
769+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xf32>
770+
679771
// -----
680772

681773
func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
@@ -689,6 +781,17 @@ func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
689781
// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
690782
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
691783

784+
func.func @masked_float_max_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f32, %arg2: vector<[2]xf32>, %m: vector<[2]xi1>) -> vector<[2]xf32> {
785+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
786+
return %0 : vector<[2]xf32>
787+
}
788+
789+
// CHECK-LABEL: func.func @masked_float_max_outerprod_scalable(
790+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> {
791+
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<[2]xf32>
792+
// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_8]], %[[VAL_2]] : vector<[2]xf32>
793+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xf32>
794+
692795
// -----
693796

694797
func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
@@ -702,6 +805,17 @@ func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
702805
// CHECK: %[[VAL_9:.*]] = arith.minnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
703806
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
704807

808+
func.func @masked_float_min_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f32, %arg2: vector<[2]xf32>, %m: vector<[2]xi1>) -> vector<[2]xf32> {
809+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
810+
return %0 : vector<[2]xf32>
811+
}
812+
813+
// CHECK-LABEL: func.func @masked_float_min_outerprod_scalable(
814+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> {
815+
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<[2]xf32>
816+
// CHECK: %[[VAL_9:.*]] = arith.minnumf %[[VAL_8]], %[[VAL_2]] : vector<[2]xf32>
817+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xf32>
818+
705819
// -----
706820

707821
func.func @masked_int_add_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
@@ -715,6 +829,17 @@ func.func @masked_int_add_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vec
715829
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
716830
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
717831

832+
func.func @masked_int_add_outerprod_scalable(%arg0: vector<[2]xi32>, %arg1: i32, %arg2: vector<[2]xi32>, %m: vector<[2]xi1>) -> vector<[2]xi32> {
833+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<[2]xi32>, i32 } : vector<[2]xi1> -> vector<[2]xi32>
834+
return %0 : vector<[2]xi32>
835+
}
836+
837+
// CHECK-LABEL: func.func @masked_int_add_outerprod_scalable(
838+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<[2]xi32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xi32> {
839+
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<[2]xi32>
840+
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_2]] : vector<[2]xi32>
841+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xi32>
842+
718843
// -----
719844

720845
func.func @masked_int_mul_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
@@ -728,6 +853,17 @@ func.func @masked_int_mul_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vec
728853
// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
729854
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
730855

856+
func.func @masked_int_mul_outerprod_scalable(%arg0: vector<[2]xi32>, %arg1: i32, %arg2: vector<[2]xi32>, %m: vector<[2]xi1>) -> vector<[2]xi32> {
857+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<[2]xi32>, i32 } : vector<[2]xi1> -> vector<[2]xi32>
858+
return %0 : vector<[2]xi32>
859+
}
860+
861+
// CHECK-LABEL: func.func @masked_int_mul_outerprod_scalable(
862+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<[2]xi32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xi32> {
863+
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<[2]xi32>
864+
// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_8]], %[[VAL_2]] : vector<[2]xi32>
865+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xi32>
866+
731867
// -----
732868

733869
func.func @masked_int_max_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
@@ -741,6 +877,17 @@ func.func @masked_int_max_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vec
741877
// CHECK: %[[VAL_9:.*]] = arith.maxsi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
742878
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
743879

880+
func.func @masked_int_max_outerprod_scalable(%arg0: vector<[2]xi32>, %arg1: i32, %arg2: vector<[2]xi32>, %m: vector<[2]xi1>) -> vector<[2]xi32> {
881+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxsi>} : vector<[2]xi32>, i32 } : vector<[2]xi1> -> vector<[2]xi32>
882+
return %0 : vector<[2]xi32>
883+
}
884+
885+
// CHECK-LABEL: func.func @masked_int_max_outerprod_scalable(
886+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<[2]xi32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xi32> {
887+
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<[2]xi32>
888+
// CHECK: %[[VAL_9:.*]] = arith.maxsi %[[VAL_8]], %[[VAL_2]] : vector<[2]xi32>
889+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xi32>
890+
744891
// -----
745892

746893
func.func @masked_int_min_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
@@ -754,6 +901,17 @@ func.func @masked_int_min_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vec
754901
// CHECK: %[[VAL_9:.*]] = arith.minui %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
755902
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
756903

904+
func.func @masked_int_min_outerprod_scalable(%arg0: vector<[2]xi32>, %arg1: i32, %arg2: vector<[2]xi32>, %m: vector<[2]xi1>) -> vector<[2]xi32> {
905+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minui>} : vector<[2]xi32>, i32 } : vector<[2]xi1> -> vector<[2]xi32>
906+
return %0 : vector<[2]xi32>
907+
}
908+
909+
// CHECK-LABEL: func.func @masked_int_min_outerprod_scalable(
910+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<[2]xi32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xi32> {
911+
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<[2]xi32>
912+
// CHECK: %[[VAL_9:.*]] = arith.minui %[[VAL_8]], %[[VAL_2]] : vector<[2]xi32>
913+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xi32>
914+
757915
// -----
758916

759917
func.func @masked_int_and_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
@@ -767,6 +925,17 @@ func.func @masked_int_and_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vec
767925
// CHECK: %[[VAL_9:.*]] = arith.andi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
768926
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
769927

928+
func.func @masked_int_and_outerprod_scalable(%arg0: vector<[2]xi32>, %arg1: i32, %arg2: vector<[2]xi32>, %m: vector<[2]xi1>) -> vector<[2]xi32> {
929+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<and>} : vector<[2]xi32>, i32 } : vector<[2]xi1> -> vector<[2]xi32>
930+
return %0 : vector<[2]xi32>
931+
}
932+
933+
// CHECK-LABEL: func.func @masked_int_and_outerprod_scalable(
934+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<[2]xi32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xi32> {
935+
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<[2]xi32>
936+
// CHECK: %[[VAL_9:.*]] = arith.andi %[[VAL_8]], %[[VAL_2]] : vector<[2]xi32>
937+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xi32>
938+
770939
// -----
771940

772941
func.func @masked_int_or_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
@@ -780,6 +949,17 @@ func.func @masked_int_or_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vect
780949
// CHECK: %[[VAL_9:.*]] = arith.ori %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
781950
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
782951

952+
func.func @masked_int_or_outerprod_scalable(%arg0: vector<[2]xi32>, %arg1: i32, %arg2: vector<[2]xi32>, %m: vector<[2]xi1>) -> vector<[2]xi32> {
953+
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<or>} : vector<[2]xi32>, i32 } : vector<[2]xi1> -> vector<[2]xi32>
954+
return %0 : vector<[2]xi32>
955+
}
956+
957+
// CHECK-LABEL: func.func @masked_int_or_outerprod_scalable
958+
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<[2]xi32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xi32> {
959+
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<[2]xi32>
960+
// CHECK: %[[VAL_9:.*]] = arith.ori %[[VAL_8]], %[[VAL_2]] : vector<[2]xi32>
961+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xi32>
962+
783963
// -----
784964

785965
func.func @shuffle_0D_direct(%arg0: vector<f32>) -> vector<3xf32> {

0 commit comments

Comments
 (0)