Skip to content

Commit e7432ba

Browse files
authored
[mlir][ArmSME] Fail instead of error in vector.outerproduct lowering (#75447)
The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently errors on unsupported cases when it should return failure.
1 parent c532ba4 commit e7432ba

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -510,16 +510,18 @@ struct VectorOuterProductToArmSMELowering
510510
// We don't yet support lowering AXPY operations to SME. These could be
511511
// lowered by masking out all but the first element of the LHS.
512512
if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
513-
return outerProductOp.emitError("AXPY operations not supported");
513+
return rewriter.notifyMatchFailure(outerProductOp,
514+
"AXPY operations not supported");
514515

515516
if (!arm_sme::isValidSMETileVectorType(
516517
outerProductOp.getResultVectorType()))
517-
return outerProductOp.emitError(
518-
"outer product does not fit into SME tile");
518+
return rewriter.notifyMatchFailure(
519+
outerProductOp, "outer product does not fit into SME tile");
519520

520521
auto kind = outerProductOp.getKind();
521522
if (kind != vector::CombiningKind::ADD)
522-
return outerProductOp.emitError(
523+
return rewriter.notifyMatchFailure(
524+
outerProductOp,
523525
"unsupported kind (lowering to SME only supports ADD at the moment)");
524526

525527
Value lhsMask = {};

mlir/test/Conversion/VectorToArmSME/unsupported.mlir

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,25 +151,31 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
151151

152152
// -----
153153

154+
// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
155+
// CHECK-NOT: arm_sme.outerproduct
156+
// CHECK: vector.outerproduct
154157
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
155-
// expected-error@+1 {{AXPY operations not supported}}
156158
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
157159
return %0 : vector<[2]xf64>
158160
}
159161

160162
// -----
161163

164+
// CHECK-LABEL: @vector_outerproduct_unsupported_kind
165+
// CHECK-NOT: arm_sme.outerproduct
166+
// CHECK: vector.outerproduct
162167
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
163168
%acc = arm_sme.get_tile : vector<[2]x[2]xf64>
164-
// expected-error@+1 {{unsupported kind}}
165169
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
166170
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
167171
}
168172

169173
// -----
170174

175+
// CHECK-LABEL: @vector_outerproduct_unknown_mask
176+
// CHECK-NOT: arm_sme.outerproduct
177+
// CHECK: vector.outerproduct
171178
func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
172-
// CHECK: vector.outerproduct
173179
%acc = arm_sme.get_tile : vector<[4]x[4]xf32>
174180
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
175181
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()

0 commit comments

Comments
 (0)