Skip to content

Commit 4ebcc9d

Browse files
committed
Address comments
add comment to clarify each variant.
1 parent 77f7fc3 commit 4ebcc9d

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -353,35 +353,51 @@ class OuterProductFusion4Way
353353

354354
arm_sme::CombiningKind kind = op.getKind();
355355
if (kind == arm_sme::CombiningKind::Add) {
356-
if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
356+
if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
357+
// signed
357358
rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
358359
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
359-
else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
360+
} else if (isa<arith::ExtUIOp>(lhsExtOp) &&
361+
isa<arith::ExtUIOp>(rhsExtOp)) {
362+
// unsigned
360363
rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
361364
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
362-
else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
365+
} else if (isa<arith::ExtSIOp>(lhsExtOp) &&
366+
isa<arith::ExtUIOp>(rhsExtOp)) {
367+
// signed by unsigned
363368
rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
364369
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
365-
else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
370+
} else if (isa<arith::ExtUIOp>(lhsExtOp) &&
371+
isa<arith::ExtSIOp>(rhsExtOp)) {
372+
// unsigned by signed
366373
rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
367374
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
368-
else
375+
} else {
369376
llvm_unreachable("unexpected extend op!");
377+
}
370378
} else if (kind == arm_sme::CombiningKind::Sub) {
371-
if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
379+
if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
380+
// signed
372381
rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
373382
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
374-
else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
383+
} else if (isa<arith::ExtUIOp>(lhsExtOp) &&
384+
isa<arith::ExtUIOp>(rhsExtOp)) {
385+
// unsigned
375386
rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
376387
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
377-
else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
388+
} else if (isa<arith::ExtSIOp>(lhsExtOp) &&
389+
isa<arith::ExtUIOp>(rhsExtOp)) {
390+
// signed by unsigned
378391
rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
379392
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
380-
else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
393+
} else if (isa<arith::ExtUIOp>(lhsExtOp) &&
394+
isa<arith::ExtSIOp>(rhsExtOp)) {
395+
// unsigned by signed
381396
rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
382397
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
383-
else
398+
} else {
384399
llvm_unreachable("unexpected extend op!");
400+
}
385401
} else {
386402
llvm_unreachable("unexpected arm_sme::CombiningKind!");
387403
}

0 commit comments

Comments
 (0)