Skip to content

Commit 452f4a1

Browse files
committed
Fixups
1 parent a4d95a4 commit 452f4a1

File tree

6 files changed

+58
-71
lines changed

6 files changed

+58
-71
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ struct TransposeOpToArmSMELowering
429429

430430
/// Conversion pattern for vector.outerproduct.
431431
///
432-
/// If the vector.outerproduct is masked (and the mask from a
432+
/// If the vector.outerproduct is masked (and the mask is from a
433433
/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
434434
/// operands.
435435
///

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,11 @@ struct MoveTileSliceToVectorArmSMELowering
460460
}
461461
};
462462

463-
/// Lower `vector.outerproduct` to SME MOPA intrinsics.
463+
/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
464464
///
465465
/// Example:
466466
///
467-
/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
467+
/// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
468468
/// : vector<[4]xf32>, vector<[4]xf32>
469469
///
470470
/// is converted to:
@@ -474,7 +474,7 @@ struct MoveTileSliceToVectorArmSMELowering
474474
/// vector<[4]xf32>) -> ()
475475
///
476476
/// Currently only supports FMOPA and BFMOPA (non-widening).
477-
struct OuterProductToArmSMELowering
477+
struct OuterProductOpConversion
478478
: public ConvertOpToLLVMPattern<arm_sme::OuterProductOp> {
479479
using ConvertOpToLLVMPattern<arm_sme::OuterProductOp>::ConvertOpToLLVMPattern;
480480

@@ -725,6 +725,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
725725
patterns.add<
726726
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
727727
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
728-
OuterProductToArmSMELowering, ZeroOpConversion,
729-
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
728+
OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
729+
VectorInsertToArmSMELowering>(converter);
730730
}

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
157157

158158
// -----
159159

160-
func.func @arm_sme_outproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
160+
func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
161161
{
162162
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
163163
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
@@ -166,7 +166,7 @@ func.func @arm_sme_outproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: ve
166166

167167
// -----
168168

169-
func.func @arm_sme_outproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB: vector<[8]xf32>) -> vector<[4]x[4]xf32>
169+
func.func @arm_sme_outerproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB: vector<[8]xf32>) -> vector<[4]x[4]xf32>
170170
{
171171
// expected-error@+1 {{op failed to verify that all of {lhs, rhs} have same type}}
172172
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[4]xf32>, vector<[8]xf32>

mlir/test/Dialect/ArmSME/roundtrip.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,39 +1168,39 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>
11681168

11691169
// -----
11701170

1171-
func.func @arm_sme_outproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[8]x[8]xi16> {
1171+
func.func @arm_sme_outerproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[8]x[8]xi16> {
11721172
// CHECK: arm_sme.outerproduct {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16>
11731173
%result = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>
11741174
return %result : vector<[8]x[8]xi16>
11751175
}
11761176

11771177
// -----
11781178

1179-
func.func @arm_sme_outproduct_with_masking(%vecA: vector<[4]xf32>, %vecB: vector<[4]xf32>, %maskA: vector<[4]xi1>, %maskB: vector<[4]xi1>) -> vector<[4]x[4]xf32> {
1179+
func.func @arm_sme_outerproduct_with_masking(%vecA: vector<[4]xf32>, %vecB: vector<[4]xf32>, %maskA: vector<[4]xi1>, %maskB: vector<[4]xi1>) -> vector<[4]x[4]xf32> {
11801180
// CHECK: arm_sme.outerproduct {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<[4]xf32>, vector<[4]xf32>
11811181
%result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<[4]xf32>, vector<[4]xf32>
11821182
return %result : vector<[4]x[4]xf32>
11831183
}
11841184

11851185
// -----
11861186

1187-
func.func @arm_sme_outproduct_with_acc(%vecA: vector<[2]xi64>, %vecB: vector<[2]xi64>, %acc: vector<[2]x[2]xi64>) -> vector<[2]x[2]xi64> {
1187+
func.func @arm_sme_outerproduct_with_acc(%vecA: vector<[2]xi64>, %vecB: vector<[2]xi64>, %acc: vector<[2]x[2]xi64>) -> vector<[2]x[2]xi64> {
11881188
// CHECK: arm_sme.outerproduct {{.*}}, {{.*}} acc({{.*}}) : vector<[2]xi64>, vector<[2]xi64>
11891189
%result = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<[2]xi64>, vector<[2]xi64>
11901190
return %result : vector<[2]x[2]xi64>
11911191
}
11921192

11931193
// -----
11941194

1195-
func.func @arm_sme_outproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2]xf64>) -> vector<[2]x[2]xf64> {
1195+
func.func @arm_sme_outerproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2]xf64>) -> vector<[2]x[2]xf64> {
11961196
// CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> : vector<[2]xf64>, vector<[2]xf64>
11971197
%result = arm_sme.outerproduct %vecA, %vecB kind<sub> : vector<[2]xf64>, vector<[2]xf64>
11981198
return %result : vector<[2]x[2]xf64>
11991199
}
12001200

12011201
// -----
12021202

1203-
func.func @arm_sme_outproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>, %acc: vector<[16]x[16]xi8>, %maskA: vector<[16]xi1>, %maskB: vector<[16]xi1>) -> vector<[16]x[16]xi8> {
1203+
func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>, %acc: vector<[16]x[16]xi8>, %maskA: vector<[16]xi1>, %maskB: vector<[16]xi1>) -> vector<[16]x[16]xi8> {
12041204
// CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> acc({{.*}}) masks({{.*}}, {{.*}}) : vector<[16]xi8>, vector<[16]xi8>
12051205
%result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
12061206
return %result : vector<[16]x[16]xi8>

mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
465465

466466
// CHECK-LABEL: @vector_outerproduct_masked_f32
467467
// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
468-
func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
468+
func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
469469
// CHECK: %[[DIM0_I32:.*]] = arith.index_cast %[[DIM0]] : index to i32
470470
// CHECK: %[[INSERT_DIM0:.*]] = llvm.insertelement %[[DIM0_I32]], {{.*}} : vector<[4]xi32>
471471
// CHECK: %[[SPLAT_DIM0:.*]] = llvm.shufflevector %[[INSERT_DIM0]], {{.*}} : vector<[4]xi32>
@@ -483,22 +483,9 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
483483

484484
// -----
485485

486-
// CHECK-LABEL: @vector_outerproduct_masked_f64
487-
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
488-
func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
489-
// CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
490-
// CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
491-
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
492-
%mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
493-
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
494-
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
495-
}
496-
497-
// -----
498-
499486
// CHECK-LABEL: @vector_outerproduct_masked_f16
500487
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
501-
func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0: index, %dim1: index) {
488+
func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
502489
// CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
503490
// CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
504491
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
@@ -511,7 +498,7 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
511498

512499
// CHECK-LABEL: @vector_outerproduct_masked_bf16
513500
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
514-
func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0: index, %dim1: index) {
501+
func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
515502
// CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
516503
// CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
517504
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
@@ -522,9 +509,9 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
522509

523510
// -----
524511

525-
// CHECK-LABEL: @vector_outerproduct_masked_f16
512+
// CHECK-LABEL: @vector_outerproduct_masked_f64
526513
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
527-
func.func @vector_outerproduct_masked_f16(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
514+
func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
528515
// CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
529516
// CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
530517
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)

mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -585,35 +585,9 @@ func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
585585

586586
// -----
587587

588-
// CHECK-LABEL: @vector_outerproduct_masked_f64
589-
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
590-
func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
591-
%mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
592-
// CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
593-
// CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
594-
// CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>
595-
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
596-
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
597-
}
598-
599-
// -----
600-
601-
// CHECK-LABEL: @vector_outerproduct_masked_f32
602-
// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
603-
func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
604-
%mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
605-
// CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
606-
// CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
607-
// CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>
608-
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
609-
"prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
610-
}
611-
612-
// -----
613-
614588
// CHECK-LABEL: @vector_outerproduct_masked_f16
615589
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
616-
func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0: index, %dim1: index) {
590+
func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
617591
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
618592
// CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
619593
// CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
@@ -626,7 +600,7 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
626600

627601
// CHECK-LABEL: @vector_outerproduct_masked_bf16
628602
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
629-
func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0: index, %dim1: index) {
603+
func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
630604
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
631605
// CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
632606
// CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
@@ -637,22 +611,28 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
637611

638612
// -----
639613

640-
// CHECK-LABEL: @vector_outerproduct_f64
641-
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
642-
func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
643-
// CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>
644-
%result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
645-
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
614+
// CHECK-LABEL: @vector_outerproduct_masked_f32
615+
// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
616+
func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
617+
%mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
618+
// CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
619+
// CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
620+
// CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>
621+
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
622+
"prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
646623
}
647624

648625
// -----
649626

650-
// CHECK-LABEL: @vector_outerproduct_f32
651-
// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
652-
func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
653-
// CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>
654-
%result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
655-
"prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
627+
// CHECK-LABEL: @vector_outerproduct_masked_f64
628+
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
629+
func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
630+
%mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
631+
// CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
632+
// CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
633+
// CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>
634+
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
635+
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
656636
}
657637

658638
// -----
@@ -674,3 +654,23 @@ func.func @vector_outerproduct_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xb
674654
%result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
675655
"prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
676656
}
657+
658+
// -----
659+
660+
// CHECK-LABEL: @vector_outerproduct_f32
661+
// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
662+
func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
663+
// CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>
664+
%result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
665+
"prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
666+
}
667+
668+
// -----
669+
670+
// CHECK-LABEL: @vector_outerproduct_f64
671+
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
672+
func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
673+
// CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>
674+
%result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
675+
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
676+
}

0 commit comments

Comments
 (0)