@@ -353,35 +353,51 @@ class OuterProductFusion4Way
353
353
354
354
arm_sme::CombiningKind kind = op.getKind ();
355
355
if (kind == arm_sme::CombiningKind::Add) {
356
- if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
356
+ if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
357
+ // signed
357
358
rewriter.replaceOpWithNewOp <arm_sme::SMopa4WayOp>(
358
359
op4, op.getResultType (), lhs, rhs, lhsMask, rhsMask, op1.getAcc ());
359
- else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
360
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
361
+ isa<arith::ExtUIOp>(rhsExtOp)) {
362
+ // unsigned
360
363
rewriter.replaceOpWithNewOp <arm_sme::UMopa4WayOp>(
361
364
op4, op.getResultType (), lhs, rhs, lhsMask, rhsMask, op1.getAcc ());
362
- else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
365
+ } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
366
+ isa<arith::ExtUIOp>(rhsExtOp)) {
367
+ // signed by unsigned
363
368
rewriter.replaceOpWithNewOp <arm_sme::SuMopa4WayOp>(
364
369
op4, op.getResultType (), lhs, rhs, lhsMask, rhsMask, op1.getAcc ());
365
- else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
370
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
371
+ isa<arith::ExtSIOp>(rhsExtOp)) {
372
+ // unsigned by signed
366
373
rewriter.replaceOpWithNewOp <arm_sme::UsMopa4WayOp>(
367
374
op4, op.getResultType (), lhs, rhs, lhsMask, rhsMask, op1.getAcc ());
368
- else
375
+ } else {
369
376
llvm_unreachable (" unexpected extend op!" );
377
+ }
370
378
} else if (kind == arm_sme::CombiningKind::Sub) {
371
- if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
379
+ if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
380
+ // signed
372
381
rewriter.replaceOpWithNewOp <arm_sme::SMops4WayOp>(
373
382
op4, op.getResultType (), lhs, rhs, lhsMask, rhsMask, op1.getAcc ());
374
- else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
383
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
384
+ isa<arith::ExtUIOp>(rhsExtOp)) {
385
+ // unsigned
375
386
rewriter.replaceOpWithNewOp <arm_sme::UMops4WayOp>(
376
387
op4, op.getResultType (), lhs, rhs, lhsMask, rhsMask, op1.getAcc ());
377
- else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
388
+ } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
389
+ isa<arith::ExtUIOp>(rhsExtOp)) {
390
+ // signed by unsigned
378
391
rewriter.replaceOpWithNewOp <arm_sme::SuMops4WayOp>(
379
392
op4, op.getResultType (), lhs, rhs, lhsMask, rhsMask, op1.getAcc ());
380
- else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
393
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
394
+ isa<arith::ExtSIOp>(rhsExtOp)) {
395
+ // unsigned by signed
381
396
rewriter.replaceOpWithNewOp <arm_sme::UsMops4WayOp>(
382
397
op4, op.getResultType (), lhs, rhs, lhsMask, rhsMask, op1.getAcc ());
383
- else
398
+ } else {
384
399
llvm_unreachable (" unexpected extend op!" );
400
+ }
385
401
} else {
386
402
llvm_unreachable (" unexpected arm_sme::CombiningKind!" );
387
403
}
0 commit comments