@@ -426,16 +426,8 @@ struct UnrolledOuterProductGenerator
426
426
}
427
427
428
428
FailureOr<Value> outerProd (Value lhs, Value rhs, Value res,
429
- VectorType lhsType, int reductionDim ,
429
+ VectorType lhsType, int reductionSize ,
430
430
std::optional<Value> maybeMask = std::nullopt) {
431
- // Unrolling a scalable dimension would be incorrect - bail out.
432
- if (lhsType.getScalableDims ()[reductionDim])
433
- return failure ();
434
-
435
- int reductionSize = lhsType.getDimSize (reductionDim);
436
- assert (reductionSize > 0 &&
437
- " Reduction dim must be a known static size to allow unrolling" );
438
-
439
431
// Incremental support for masking.
440
432
if (mask && !maybeMask.has_value ())
441
433
return failure ();
@@ -458,49 +450,93 @@ struct UnrolledOuterProductGenerator
458
450
return res;
459
451
}
460
452
453
+ /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
454
+ /// dimension `reductionDim`. If the dimension is a scalable dimension,
455
+ /// returns "nullopt".
456
+ std::optional<int64_t > getReductionSize (VectorType vecType,
457
+ int64_t reductionDim) {
458
+ // Cannot unroll scalable dimension.
459
+ if (vecType.getScalableDims ()[reductionDim])
460
+ return std::nullopt;
461
+ int64_t reductionSize = vecType.getDimSize (reductionDim);
462
+ assert (reductionSize > 0 &&
463
+ " Reduction dim must be a known static size to allow unrolling" );
464
+ return reductionSize;
465
+ }
466
+
461
467
/// Two outer parallel, one inner reduction (matmat flavor).
462
468
FailureOr<Value> matmat () {
463
469
if (!iters ({Par (), Par (), Red ()}))
464
470
return failure ();
465
471
// Set up the parallel/reduction structure in the right form.
466
472
AffineExpr m, n, k;
467
473
bindDims (rewriter.getContext (), m, n, k);
468
- Value transposedMask = t (mask, { 2 , 0 , 1 });
474
+
469
475
// Classical row-major matmul: Just permute the lhs.
470
- if (layout ({{m, k}, {k, n}, {m, n}}))
471
- return outerProd (t (lhs), rhs, res, lhsType, /* reductionDim=*/ 1 ,
472
- transposedMask);
476
+ if (layout ({{m, k}, {k, n}, {m, n}})) {
477
+ if (auto reductionSize = getReductionSize (lhsType, 1 )) {
478
+ // Note: `t` creates new IR. It must be nested within this `if` check
479
+ // so that no IR is created when the pattern returns "failure".
480
+ Value tLhs = t (lhs);
481
+ Value tMask = t (mask, {2 , 0 , 1 });
482
+ return outerProd (tLhs, rhs, res, lhsType, *reductionSize, tMask);
483
+ }
484
+ }
473
485
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
474
486
if (layout ({{m, k}, {n, k}, {m, n}})) {
475
- Value tlhs = t (lhs);
476
- return outerProd (tlhs, t (rhs), res, lhsType, /* reductionDim=*/ 1 ,
477
- transposedMask);
487
+ if (auto reductionSize = getReductionSize (lhsType, 1 )) {
488
+ Value tLhs = t (lhs);
489
+ Value tRhs = t (rhs);
490
+ Value tMask = t (mask, {2 , 0 , 1 });
491
+ return outerProd (tLhs, tRhs, res, lhsType, *reductionSize, tMask);
492
+ }
478
493
}
479
494
// No need to permute anything.
480
- if (layout ({{k, m}, {k, n}, {m, n}}))
481
- return outerProd (lhs, rhs, res, lhsType, /* reductionDim=*/ 0 ,
482
- transposedMask);
495
+ if (layout ({{k, m}, {k, n}, {m, n}})) {
496
+ if (auto reductionSize = getReductionSize (lhsType, 0 )) {
497
+ Value tMask = t (mask, {2 , 0 , 1 });
498
+ return outerProd (lhs, rhs, res, lhsType, *reductionSize, tMask);
499
+ }
500
+ }
483
501
// Just permute the rhs.
484
- if (layout ({{k, m}, {n, k}, {m, n}}))
485
- return outerProd (lhs, t (rhs), res, lhsType, /* reductionDim=*/ 0 ,
486
- transposedMask);
502
+ if (layout ({{k, m}, {n, k}, {m, n}})) {
503
+ if (auto reductionSize = getReductionSize (lhsType, 0 )) {
504
+ Value tRhs = t (rhs);
505
+ Value tMask = t (mask, {2 , 0 , 1 });
506
+ return outerProd (lhs, tRhs, res, lhsType, *reductionSize, tMask);
507
+ }
508
+ }
487
509
// Transposed output: swap RHS and LHS.
488
510
// Classical row-major matmul: permute the lhs.
489
- if (layout ({{m, k}, {k, n}, {n, m}}))
490
- return outerProd (rhs, t (lhs), res, lhsType, /* reductionDim=*/ 1 ,
491
- transposedMask);
511
+ if (layout ({{m, k}, {k, n}, {n, m}})) {
512
+ if (auto reductionSize = getReductionSize (lhsType, 1 )) {
513
+ Value tLhs = t (lhs);
514
+ Value tMask = t (mask, {2 , 0 , 1 });
515
+ return outerProd (rhs, tLhs, res, lhsType, *reductionSize, tMask);
516
+ }
517
+ }
492
518
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
493
519
if (layout ({{m, k}, {n, k}, {n, m}})) {
494
- Value trhs = t (rhs);
495
- return outerProd (trhs, t (lhs), res, lhsType, /* reductionDim=*/ 1 ,
496
- transposedMask);
520
+ if (auto reductionSize = getReductionSize (lhsType, 1 )) {
521
+ Value tRhs = t (rhs);
522
+ Value tLhs = t (lhs);
523
+ Value tMask = t (mask, {2 , 0 , 1 });
524
+ return outerProd (tRhs, tLhs, res, lhsType, *reductionSize, tMask);
525
+ }
526
+ }
527
+ if (layout ({{k, m}, {k, n}, {n, m}})) {
528
+ if (auto reductionSize = getReductionSize (lhsType, 0 )) {
529
+ Value tMask = t (mask, {2 , 0 , 1 });
530
+ return outerProd (rhs, lhs, res, lhsType, *reductionSize, tMask);
531
+ }
532
+ }
533
+ if (layout ({{k, m}, {n, k}, {n, m}})) {
534
+ if (auto reductionSize = getReductionSize (lhsType, 0 )) {
535
+ Value tRhs = t (rhs);
536
+ Value tMask = t (mask, {2 , 0 , 1 });
537
+ return outerProd (tRhs, lhs, res, lhsType, *reductionSize, tMask);
538
+ }
497
539
}
498
- if (layout ({{k, m}, {k, n}, {n, m}}))
499
- return outerProd (rhs, lhs, res, lhsType, /* reductionDim=*/ 0 ,
500
- transposedMask);
501
- if (layout ({{k, m}, {n, k}, {n, m}}))
502
- return outerProd (t (rhs), lhs, res, lhsType, /* reductionDim=*/ 0 ,
503
- transposedMask);
504
540
return failure ();
505
541
}
506
542
@@ -514,24 +550,37 @@ struct UnrolledOuterProductGenerator
514
550
return failure ();
515
551
AffineExpr m, k;
516
552
bindDims (rewriter.getContext (), m, k);
517
- Value transposedMask = t (mask);
518
553
519
554
// Case mat-vec: transpose.
520
- if (layout ({{m, k}, {k}, {m}}))
521
- return outerProd (t (lhs), rhs, res, lhsType, /* reductionDim=*/ 1 ,
522
- transposedMask);
555
+ if (layout ({{m, k}, {k}, {m}})) {
556
+ if (auto reductionSize = getReductionSize (lhsType, 1 )) {
557
+ Value tLhs = t (lhs);
558
+ Value tMask = t (mask);
559
+ return outerProd (tLhs, rhs, res, lhsType, *reductionSize, tMask);
560
+ }
561
+ }
523
562
// Case mat-trans-vec: ready to go.
524
- if (layout ({{k, m}, {k}, {m}}))
525
- return outerProd (lhs, rhs, res, lhsType, /* reductionDim=*/ 0 ,
526
- transposedMask);
563
+ if (layout ({{k, m}, {k}, {m}})) {
564
+ if (auto reductionSize = getReductionSize (lhsType, 0 )) {
565
+ Value tMask = t (mask);
566
+ return outerProd (lhs, rhs, res, lhsType, *reductionSize, tMask);
567
+ }
568
+ }
527
569
// Case vec-mat: swap and transpose.
528
- if (layout ({{k}, {m, k}, {m}}))
529
- return outerProd (t (rhs), lhs, res, lhsType, /* reductionDim=*/ 0 ,
530
- transposedMask);
570
+ if (layout ({{k}, {m, k}, {m}})) {
571
+ if (auto reductionSize = getReductionSize (lhsType, 0 )) {
572
+ Value tRhs = t (rhs);
573
+ Value tMask = t (mask);
574
+ return outerProd (tRhs, lhs, res, lhsType, *reductionSize, tMask);
575
+ }
576
+ }
531
577
// Case vec-mat-trans: swap and ready to go.
532
- if (layout ({{k}, {k, m}, {m}}))
533
- return outerProd (rhs, lhs, res, lhsType, /* reductionDim=*/ 0 ,
534
- transposedMask);
578
+ if (layout ({{k}, {k, m}, {m}})) {
579
+ if (auto reductionSize = getReductionSize (lhsType, 0 )) {
580
+ Value tMask = t (mask);
581
+ return outerProd (rhs, lhs, res, lhsType, *reductionSize, tMask);
582
+ }
583
+ }
535
584
return failure ();
536
585
}
537
586
@@ -547,16 +596,20 @@ struct UnrolledOuterProductGenerator
547
596
548
597
// Case mat-vec: transpose.
549
598
if (layout ({{m, k}, {k}, {m}}))
550
- return outerProd (t (lhs), rhs, res, lhsType, /* reductionDim=*/ 1 , mask);
599
+ if (auto reductionSize = getReductionSize (lhsType, 1 ))
600
+ return outerProd (t (lhs), rhs, res, lhsType, *reductionSize, mask);
551
601
// Case mat-trans-vec: ready to go.
552
602
if (layout ({{k, m}, {k}, {m}}))
553
- return outerProd (lhs, rhs, res, lhsType, /* reductionDim=*/ 0 , mask);
603
+ if (auto reductionSize = getReductionSize (lhsType, 0 ))
604
+ return outerProd (lhs, rhs, res, lhsType, *reductionSize, mask);
554
605
// Case vec-mat: swap and transpose.
555
606
if (layout ({{k}, {m, k}, {m}}))
556
- return outerProd (t (rhs), lhs, res, lhsType, /* reductionDim=*/ 0 , mask);
607
+ if (auto reductionSize = getReductionSize (lhsType, 0 ))
608
+ return outerProd (t (rhs), lhs, res, lhsType, *reductionSize, mask);
557
609
// Case vec-mat-trans: swap and ready to go.
558
610
if (layout ({{k}, {k, m}, {m}}))
559
- return outerProd (rhs, lhs, res, lhsType, /* reductionDim=*/ 0 , mask);
611
+ if (auto reductionSize = getReductionSize (lhsType, 0 ))
612
+ return outerProd (rhs, lhs, res, lhsType, *reductionSize, mask);
560
613
return failure ();
561
614
}
562
615
0 commit comments