@@ -282,51 +282,38 @@ class OuterProductFusion4Way
282
282
283
283
LogicalResult matchAndRewrite (arm_sme::OuterProductOp op,
284
284
PatternRewriter &rewriter) const override {
285
- Value acc = op.getAcc ();
286
- if (!acc)
287
- return rewriter.notifyMatchFailure (op, MATCH_FAILURE_NO_ACCUMULATOR);
288
-
289
- arm_sme::OuterProductOp op4 = op;
290
- arm_sme::OuterProductOp op3 = acc.getDefiningOp <arm_sme::OuterProductOp>();
291
- if (!op3)
292
- return rewriter.notifyMatchFailure (
293
- op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
294
-
295
- acc = op3.getAcc ();
296
- if (!acc)
297
- return rewriter.notifyMatchFailure (op, MATCH_FAILURE_NO_ACCUMULATOR);
298
-
299
- arm_sme::OuterProductOp op2 = acc.getDefiningOp <arm_sme::OuterProductOp>();
300
- if (!op2)
301
- return rewriter.notifyMatchFailure (
302
- op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
303
-
304
- acc = op2.getAcc ();
305
- if (!acc)
306
- return rewriter.notifyMatchFailure (op, MATCH_FAILURE_NO_ACCUMULATOR);
307
-
308
- arm_sme::OuterProductOp op1 = acc.getDefiningOp <arm_sme::OuterProductOp>();
309
- if (!op1)
310
- return rewriter.notifyMatchFailure (
311
- op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
312
-
313
- arm_sme::CombiningKind kind = op1.getKind ();
314
- if (op2.getKind () != kind || op3.getKind () != kind || op4.getKind () != kind)
315
- return rewriter.notifyMatchFailure (
316
- op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
317
-
318
- if (!op1->hasOneUse () || !op2->hasOneUse () || !op3->hasOneUse ())
319
- return rewriter.notifyMatchFailure (
320
- op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
321
-
322
- if (bool (op1.getLhsMask ()) != bool (op2.getLhsMask ()) !=
323
- bool (op3.getLhsMask ()) != bool (op4.getLhsMask ()))
324
- return rewriter.notifyMatchFailure (op,
325
- MATCH_FAILURE_INCONSISTENT_MASKING);
285
+ // Collect the chain of four outer products, starting from `op` and following
+ // each accumulator operand to the outer product that defines it.
+ // outerProductChain[0] is `op` (the last in the chain) and
+ // outerProductChain[3] is the first.
+ SmallVector<arm_sme::OuterProductOp, 4 > outerProductChain;
286
+ outerProductChain.push_back (op);
287
+
288
+ for (int i = 0 ; i < 3 ; ++i) {
289
+ auto currentOp = outerProductChain.back ();
290
+ auto acc = currentOp.getAcc ();
291
+ if (!acc)
292
+ return rewriter.notifyMatchFailure (op, MATCH_FAILURE_NO_ACCUMULATOR);
293
+ auto previousOp = acc.getDefiningOp <arm_sme::OuterProductOp>();
294
+ if (!previousOp)
295
+ return rewriter.notifyMatchFailure (
296
+ op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
297
+ if (!previousOp->hasOneUse ())
298
+ return rewriter.notifyMatchFailure (
299
+ op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
+ // Comparing adjacent pairs is enough: if every neighbouring pair agrees,
+ // all four ops share the same combining kind (and, below, masking).
300
+ if (previousOp.getKind () != currentOp.getKind ())
301
+ return rewriter.notifyMatchFailure (
302
+ op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+ // Either every op in the chain is masked or none is; report a masking
+ // mismatch with the masking diagnostic, not the combining-kind one.
303
+ if (bool (previousOp.getLhsMask ()) != bool (currentOp.getLhsMask ()))
304
+ return rewriter.notifyMatchFailure (
305
+ op, MATCH_FAILURE_INCONSISTENT_MASKING);
306
+ outerProductChain.push_back (previousOp);
307
+ }
326
308
327
- if (failed (canFuseOuterProducts (rewriter, op1, op2, op3, op4 )))
309
+ if (failed (canFuseOuterProducts (rewriter, outerProductChain )))
328
310
return failure ();
329
311
312
+ arm_sme::OuterProductOp op1 = outerProductChain[3 ];
313
+ arm_sme::OuterProductOp op2 = outerProductChain[2 ];
314
+ arm_sme::OuterProductOp op3 = outerProductChain[1 ];
315
+ arm_sme::OuterProductOp op4 = outerProductChain[0 ];
316
+
330
317
auto loc = op.getLoc ();
331
318
332
319
auto packInputs = [&](Value lhs, Value rhs) {
@@ -364,6 +351,7 @@ class OuterProductFusion4Way
364
351
auto lhsExtOp = op.getLhs ().getDefiningOp ();
365
352
auto rhsExtOp = op.getRhs ().getDefiningOp ();
366
353
354
+ arm_sme::CombiningKind kind = op.getKind ();
367
355
if (kind == arm_sme::CombiningKind::Add) {
368
356
if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
369
357
rewriter.replaceOpWithNewOp <arm_sme::SMopa4WayOp>(
@@ -414,94 +402,41 @@ class OuterProductFusion4Way
414
402
// - a floating-point extension for floating-point types.
415
403
// - the types and extension are supported, i.e. there's a 4-way operation
416
404
// they can be fused into.
417
- LogicalResult canFuseOuterProducts (PatternRewriter &rewriter,
418
- arm_sme::OuterProductOp op1,
419
- arm_sme::OuterProductOp op2,
420
- arm_sme::OuterProductOp op3,
421
- arm_sme::OuterProductOp op4) const {
405
+ LogicalResult
406
+ canFuseOuterProducts (PatternRewriter &rewriter,
407
+ SmallVectorImpl<arm_sme::OuterProductOp> &ops) const {
422
408
// Supported result types.
423
409
auto nxnxv4i32 =
424
410
VectorType::get ({4 , 4 }, rewriter.getI32Type (), {true , true });
425
411
auto nxnxv2i64 =
426
412
VectorType::get ({2 , 2 }, rewriter.getI64Type (), {true , true });
413
+
427
414
// Supported input types.
428
415
// Note: this is before packing so these have 1/4 the number of elements
429
416
// of the input vector types of the 4-way operations.
430
417
auto nxv4i8 = VectorType::get ({4 }, rewriter.getI8Type (), true );
431
418
auto nxv2i16 = VectorType::get ({2 }, rewriter.getI16Type (), true );
432
- if (
433
- // signed, i8i8i32
434
- (failed (
435
- isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
436
- failed (
437
- isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
438
- failed (
439
- isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
440
- failed (
441
- isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
442
- // signed, i16i16i64
443
- (failed (
444
- isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
445
- failed (
446
- isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
447
- failed (
448
- isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
449
- failed (isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv2i64,
450
- nxv2i16))) &&
451
- // unsigned, i8i8i32
452
- (failed (
453
- isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
454
- failed (
455
- isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
456
- failed (
457
- isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
458
- failed (
459
- isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
460
- // unsigned, i16i16i64
461
- (failed (
462
- isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
463
- failed (
464
- isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
465
- failed (
466
- isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
467
- failed (isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv2i64,
468
- nxv2i16))) &&
469
- // signed by unsigned, i8i8i32
470
- (failed (isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
471
- rewriter, op1, nxnxv4i32, nxv4i8)) ||
472
- failed (isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
473
- rewriter, op2, nxnxv4i32, nxv4i8)) ||
474
- failed (isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
475
- rewriter, op3, nxnxv4i32, nxv4i8)) ||
476
- failed (isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
477
- rewriter, op4, nxnxv4i32, nxv4i8))) &&
478
- // signed by unsigned, i16i16i64
479
- (failed (isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
480
- rewriter, op1, nxnxv2i64, nxv2i16)) ||
481
- failed (isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
482
- rewriter, op2, nxnxv2i64, nxv2i16)) ||
483
- failed (isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
484
- rewriter, op3, nxnxv2i64, nxv2i16)) ||
485
- failed (isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
486
- rewriter, op4, nxnxv2i64, nxv2i16))) &&
487
- // unsigned by signed, i8i8i32
488
- (failed (isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
489
- rewriter, op1, nxnxv4i32, nxv4i8)) ||
490
- failed (isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
491
- rewriter, op2, nxnxv4i32, nxv4i8)) ||
492
- failed (isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
493
- rewriter, op3, nxnxv4i32, nxv4i8)) ||
494
- failed (isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
495
- rewriter, op4, nxnxv4i32, nxv4i8))) &&
496
- // unsigned by signed, i16i16i64
497
- (failed (isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
498
- rewriter, op1, nxnxv2i64, nxv2i16)) ||
499
- failed (isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
500
- rewriter, op2, nxnxv2i64, nxv2i16)) ||
501
- failed (isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
502
- rewriter, op3, nxnxv2i64, nxv2i16)) ||
503
- failed (isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
504
- rewriter, op4, nxnxv2i64, nxv2i16))))
419
+
420
+ auto failedToMatch = [&](VectorType resultType, VectorType inputType,
421
+ auto lhsExtendOp, auto rhsExtendOp) {
422
+ using LhsExtendOpTy = decltype (lhsExtendOp);
423
+ using RhsExtendOpTy = decltype (rhsExtendOp);
424
+ for (auto op : ops) {
425
+ if (failed (isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
426
+ rewriter, op, resultType, inputType)))
427
+ return true ;
428
+ }
429
+ return false ;
430
+ };
431
+
432
+ if (failedToMatch (nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
433
+ failedToMatch (nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
434
+ failedToMatch (nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
435
+ failedToMatch (nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
436
+ failedToMatch (nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
437
+ failedToMatch (nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
438
+ failedToMatch (nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
439
+ failedToMatch (nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
505
440
return failure ();
506
441
507
442
return success ();
0 commit comments