Skip to content

Commit 77f7fc3

Browse files
committed
Address comments. Changes:
- fix check for consistent masking. - rewrite as loop that walks outer product chain. - use lambda for match check.
1 parent 25af36c commit 77f7fc3

File tree

1 file changed

+55
-120
lines changed

1 file changed

+55
-120
lines changed

mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp

Lines changed: 55 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -282,51 +282,38 @@ class OuterProductFusion4Way
282282

283283
LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
284284
PatternRewriter &rewriter) const override {
285-
Value acc = op.getAcc();
286-
if (!acc)
287-
return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
288-
289-
arm_sme::OuterProductOp op4 = op;
290-
arm_sme::OuterProductOp op3 = acc.getDefiningOp<arm_sme::OuterProductOp>();
291-
if (!op3)
292-
return rewriter.notifyMatchFailure(
293-
op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
294-
295-
acc = op3.getAcc();
296-
if (!acc)
297-
return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
298-
299-
arm_sme::OuterProductOp op2 = acc.getDefiningOp<arm_sme::OuterProductOp>();
300-
if (!op2)
301-
return rewriter.notifyMatchFailure(
302-
op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
303-
304-
acc = op2.getAcc();
305-
if (!acc)
306-
return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
307-
308-
arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
309-
if (!op1)
310-
return rewriter.notifyMatchFailure(
311-
op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
312-
313-
arm_sme::CombiningKind kind = op1.getKind();
314-
if (op2.getKind() != kind || op3.getKind() != kind || op4.getKind() != kind)
315-
return rewriter.notifyMatchFailure(
316-
op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
317-
318-
if (!op1->hasOneUse() || !op2->hasOneUse() || !op3->hasOneUse())
319-
return rewriter.notifyMatchFailure(
320-
op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
321-
322-
if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()) !=
323-
bool(op3.getLhsMask()) != bool(op4.getLhsMask()))
324-
return rewriter.notifyMatchFailure(op,
325-
MATCH_FAILURE_INCONSISTENT_MASKING);
285+
SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
286+
outerProductChain.push_back(op);
287+
288+
for (int i = 0; i < 3; ++i) {
289+
auto currentOp = outerProductChain.back();
290+
auto acc = currentOp.getAcc();
291+
if (!acc)
292+
return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
293+
auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
294+
if (!previousOp)
295+
return rewriter.notifyMatchFailure(
296+
op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
297+
if (!previousOp->hasOneUse())
298+
return rewriter.notifyMatchFailure(
299+
op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
300+
if (previousOp.getKind() != currentOp.getKind())
301+
return rewriter.notifyMatchFailure(
302+
op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
303+
if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
304+
return rewriter.notifyMatchFailure(
305+
op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
306+
outerProductChain.push_back(previousOp);
307+
}
326308

327-
if (failed(canFuseOuterProducts(rewriter, op1, op2, op3, op4)))
309+
if (failed(canFuseOuterProducts(rewriter, outerProductChain)))
328310
return failure();
329311

312+
arm_sme::OuterProductOp op1 = outerProductChain[3];
313+
arm_sme::OuterProductOp op2 = outerProductChain[2];
314+
arm_sme::OuterProductOp op3 = outerProductChain[1];
315+
arm_sme::OuterProductOp op4 = outerProductChain[0];
316+
330317
auto loc = op.getLoc();
331318

332319
auto packInputs = [&](Value lhs, Value rhs) {
@@ -364,6 +351,7 @@ class OuterProductFusion4Way
364351
auto lhsExtOp = op.getLhs().getDefiningOp();
365352
auto rhsExtOp = op.getRhs().getDefiningOp();
366353

354+
arm_sme::CombiningKind kind = op.getKind();
367355
if (kind == arm_sme::CombiningKind::Add) {
368356
if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
369357
rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
@@ -414,94 +402,41 @@ class OuterProductFusion4Way
414402
// - a floating-point extension for floating-point types.
415403
// - the types and extension are supported, i.e. there's a 4-way operation
416404
// they can be fused into.
417-
LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
418-
arm_sme::OuterProductOp op1,
419-
arm_sme::OuterProductOp op2,
420-
arm_sme::OuterProductOp op3,
421-
arm_sme::OuterProductOp op4) const {
405+
LogicalResult
406+
canFuseOuterProducts(PatternRewriter &rewriter,
407+
SmallVectorImpl<arm_sme::OuterProductOp> &ops) const {
422408
// Supported result types.
423409
auto nxnxv4i32 =
424410
VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
425411
auto nxnxv2i64 =
426412
VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
413+
427414
// Supported input types.
428415
// Note: this is before packing so these have 1/4 the number of elements
429416
// of the input vector types of the 4-way operations.
430417
auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
431418
auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
432-
if (
433-
// signed, i8i8i32
434-
(failed(
435-
isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
436-
failed(
437-
isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
438-
failed(
439-
isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
440-
failed(
441-
isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
442-
// signed, i16i16i64
443-
(failed(
444-
isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
445-
failed(
446-
isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
447-
failed(
448-
isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
449-
failed(isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv2i64,
450-
nxv2i16))) &&
451-
// unsigned, i8i8i32
452-
(failed(
453-
isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
454-
failed(
455-
isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
456-
failed(
457-
isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
458-
failed(
459-
isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
460-
// unsigned, i16i16i64
461-
(failed(
462-
isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
463-
failed(
464-
isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
465-
failed(
466-
isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
467-
failed(isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv2i64,
468-
nxv2i16))) &&
469-
// signed by unsigned, i8i8i32
470-
(failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
471-
rewriter, op1, nxnxv4i32, nxv4i8)) ||
472-
failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
473-
rewriter, op2, nxnxv4i32, nxv4i8)) ||
474-
failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
475-
rewriter, op3, nxnxv4i32, nxv4i8)) ||
476-
failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
477-
rewriter, op4, nxnxv4i32, nxv4i8))) &&
478-
// signed by unsigned, i16i16i64
479-
(failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
480-
rewriter, op1, nxnxv2i64, nxv2i16)) ||
481-
failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
482-
rewriter, op2, nxnxv2i64, nxv2i16)) ||
483-
failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
484-
rewriter, op3, nxnxv2i64, nxv2i16)) ||
485-
failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
486-
rewriter, op4, nxnxv2i64, nxv2i16))) &&
487-
// unsigned by signed, i8i8i32
488-
(failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
489-
rewriter, op1, nxnxv4i32, nxv4i8)) ||
490-
failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
491-
rewriter, op2, nxnxv4i32, nxv4i8)) ||
492-
failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
493-
rewriter, op3, nxnxv4i32, nxv4i8)) ||
494-
failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
495-
rewriter, op4, nxnxv4i32, nxv4i8))) &&
496-
// unsigned by signed, i16i16i64
497-
(failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
498-
rewriter, op1, nxnxv2i64, nxv2i16)) ||
499-
failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
500-
rewriter, op2, nxnxv2i64, nxv2i16)) ||
501-
failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
502-
rewriter, op3, nxnxv2i64, nxv2i16)) ||
503-
failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
504-
rewriter, op4, nxnxv2i64, nxv2i16))))
419+
420+
auto failedToMatch = [&](VectorType resultType, VectorType inputType,
421+
auto lhsExtendOp, auto rhsExtendOp) {
422+
using LhsExtendOpTy = decltype(lhsExtendOp);
423+
using RhsExtendOpTy = decltype(rhsExtendOp);
424+
for (auto op : ops) {
425+
if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
426+
rewriter, op, resultType, inputType)))
427+
return true;
428+
}
429+
return false;
430+
};
431+
432+
if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
433+
failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
434+
failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
435+
failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
436+
failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
437+
failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
438+
failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
439+
failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
505440
return failure();
506441

507442
return success();

0 commit comments

Comments
 (0)