@@ -60,9 +60,51 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   }
 };
 
+struct SelfConcatToTile : public OpRewritePattern<tosa::ConcatOp> {
+  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::all_equal(concatOp->getUsers())) {
+      const auto concatUser = llvm::dyn_cast<tosa::ConcatOp>(
+          concatOp->getUses().begin()->getOwner());
+      if (concatUser) {
+        // Try folding the concat into its consumer before rewriting it to a
+        // tile.
+        SmallVector<Value> replacementValues;
+        auto foldResult = rewriter.tryFold(concatUser, replacementValues);
+        if (foldResult.succeeded()) {
+          if (!replacementValues.empty()) {
+            rewriter.replaceOp(concatUser, replacementValues);
+          }
+          return success();
+        }
+      }
+    }
+
+    if (!llvm::all_equal(concatOp->getOperands())) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "Requires all operands to be the same");
+    }
+    const auto concatType = dyn_cast<ShapedType>(concatOp.getType());
+    if (!concatType || !concatType.hasRank()) {
+      return rewriter.notifyMatchFailure(concatOp,
+                                         "Requires concat to be ranked");
+    }
+    SmallVector<int64_t> multiplies(concatType.getRank(), 1);
+    multiplies[concatOp.getAxis()] = concatOp->getNumOperands();
+    auto tileOp = rewriter.createOrFold<tosa::TileOp>(
+        concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0),
+        multiplies);
+    rewriter.replaceOp(concatOp, {tileOp});
+    return success();
+  }
+};
+
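The pattern above rests on the identity that concatenating a tensor with itself n times along one axis produces exactly the elements of a tile with multiple n on that axis. A minimal sketch of that identity for axis 1 of a row-major matrix, assuming plain C++ with hypothetical helper names and no MLIR dependencies:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Concatenate a row-major h x w matrix with itself n times along axis 1:
// each row's data is repeated n times back to back.
std::vector<int> concatSelfAxis1(const std::vector<int> &m, size_t h,
                                 size_t w, size_t n) {
  std::vector<int> out;
  for (size_t r = 0; r < h; ++r)
    for (size_t i = 0; i < n; ++i)
      out.insert(out.end(), m.begin() + r * w, m.begin() + (r + 1) * w);
  return out;
}

// Tile with multiples {1, n}: every output column c reads input column
// c % w, which is the modulo indexing a tile materializes.
std::vector<int> tileAxis1(const std::vector<int> &m, size_t h, size_t w,
                           size_t n) {
  std::vector<int> out(h * w * n);
  for (size_t r = 0; r < h; ++r)
    for (size_t c = 0; c < w * n; ++c)
      out[r * w * n + c] = m[r * w + c % w];
  return out;
}

int main() {
  const std::vector<int> m = {1, 2, 3, 4, 5, 6}; // 2x3
  assert(concatSelfAxis1(m, 2, 3, 4) == tileAxis1(m, 2, 3, 4));
}
```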
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<ConcatOptimization>(context);
+  results.add<SelfConcatToTile>(context);
 }
 
 struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::PowOp> {
@@ -611,42 +653,120 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
 
     llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
     llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
-
-    // Validate slice on the concatenated axis. Slicing along this
-    // axis should span only one of the inputs to the concatenate
-    // operation.
-    std::optional<Value> replaceWithSlice;
+    llvm::SmallVector<Value> requiredConcatInputs;
+    int64_t processedOriginalConcatInputSize = 0;
+    int64_t droppedConcatInputSize = 0;
     for (auto input : inputs) {
-      auto inputType = dyn_cast<RankedTensorType>(input.getType());
+      const auto inputType = dyn_cast<RankedTensorType>(input.getType());
       if (!inputType || !inputType.hasStaticShape())
         return rewriter.notifyMatchFailure(
             sliceOp, "concat input must be a static ranked tensor");
-
-      if (sliceStart[axis] >= 0 &&
-          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
-        replaceWithSlice = rewriter
-                               .create<tosa::SliceOp>(
-                                   sliceOp.getLoc(), sliceOp.getType(), input,
-                                   rewriter.getDenseI64ArrayAttr(sliceStart),
-                                   rewriter.getDenseI64ArrayAttr(sliceSize))
-                               .getResult();
-        break;
+      if (processedOriginalConcatInputSize <
+              (sliceStart[axis] + sliceSize[axis]) &&
+          (processedOriginalConcatInputSize + inputType.getDimSize(axis)) >
+              sliceStart[axis]) {
+        if (requiredConcatInputs.empty()) {
+          droppedConcatInputSize = processedOriginalConcatInputSize;
+        }
+        requiredConcatInputs.push_back(input);
       }
-      sliceStart[axis] -= inputType.getDimSize(axis);
+      processedOriginalConcatInputSize += inputType.getDimSize(axis);
+    }
+    if (requiredConcatInputs.size() == concatOp->getNumOperands()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "Could not reduce number of inputs to preceding concat");
+    }
+    if (requiredConcatInputs.size() != 1 && !concatOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "Preceding concat must have a single use"); // Do not introduce new
+                                                      // concats
+    }
+    if (requiredConcatInputs.empty()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "degenerate slice with zero sized dim in output");
     }
+    sliceStart[axis] -= droppedConcatInputSize;
+    auto newConcat = rewriter.create<tosa::ConcatOp>(
+        concatOp->getLoc(), requiredConcatInputs, axis);
+    auto newSlice = rewriter.create<tosa::SliceOp>(
+        sliceOp->getLoc(), sliceOp.getType(), newConcat,
+        rewriter.getDenseI64ArrayAttr(sliceStart),
+        rewriter.getDenseI64ArrayAttr(sliceSize));
+    rewriter.replaceOp(sliceOp, newSlice);
+    return success();
+  }
+};
+
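Unlike the code path it replaces, which bailed out unless the slice fit entirely inside a single concat operand, the loop above keeps every operand the slice window [start, start + size) overlaps and rebases the slice start by the total size of the dropped prefix. A minimal sketch of that overlap bookkeeping, assuming plain C++ and a hypothetical coveringInputs helper that is not part of the patch:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Given per-input sizes along the concat axis, return the indices of the
// inputs that the slice window [start, start + size) overlaps -- the subset
// the rewritten concat keeps. Mirrors the overlap test in the loop above.
std::vector<size_t> coveringInputs(const std::vector<int64_t> &inputSizes,
                                   int64_t start, int64_t size) {
  std::vector<size_t> kept;
  int64_t processed = 0;
  for (size_t i = 0; i < inputSizes.size(); ++i) {
    if (processed < start + size && processed + inputSizes[i] > start)
      kept.push_back(i);
    processed += inputSizes[i];
  }
  return kept;
}

int main() {
  // Concat inputs of sizes 4, 4, 4 along the axis; a slice [start=5, size=6)
  // touches elements 5..10, i.e. inputs 1 and 2 only. Input 0 is dropped and
  // the slice start is rebased to 5 - 4 = 1.
  assert((coveringInputs({4, 4, 4}, 5, 6) == std::vector<size_t>{1, 2}));
}
```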
+/// This pattern adjusts the multipliers of a tile followed by a slice to only
+/// tile as much data as is required by the slice.
+struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+    auto tileOp = sliceInput.getDefiningOp<tosa::TileOp>();
+    if (!tileOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be tile operation");
+    if (!tileOp->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "preceding tile must have a single use"); // Do not insert
+                                                             // additional tiles
 
-    if (!replaceWithSlice)
+    const auto tileOpInputType =
+        dyn_cast<RankedTensorType>(tileOp->getOperand(0).getType());
+    if (!tileOpInputType || !tileOpInputType.hasStaticShape())
       return rewriter.notifyMatchFailure(
-          sliceOp, "corresponding concat input not found for slice");
+          sliceOp, "input to preceding tile op must be a static ranked tensor");
+    llvm::SmallVector<int64_t> requiredMultipliers;
+    llvm::SmallVector<int64_t> newTileStarts;
+    requiredMultipliers.reserve(tileOpInputType.getRank());
+    newTileStarts.reserve(tileOpInputType.getRank());
+    for (auto [axis, sliceStart, sliceSize] :
+         llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) {
+      if (sliceSize <= 0) {
+        return rewriter.notifyMatchFailure(
+            sliceOp, "degenerate slice with zero sized dim");
+      }
+      const int64_t tileInputDimSize = tileOpInputType.getDimSize(axis);
+      const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize;
+      const int64_t sliceSizeInFirstTile =
+          std::min(tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize);
+      assert(sliceSizeInFirstTile > 0);
+      const int64_t requiredMultiplierWithoutFirstTile =
+          llvm::divideCeil(sliceSize - sliceSizeInFirstTile, tileInputDimSize);
+      const int64_t requiredMultiplier =
+          requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0);
+      assert(requiredMultiplier <= tileOp.getMultiples()[axis]);
+      requiredMultipliers.push_back(requiredMultiplier);
+      newTileStarts.push_back(sliceOffsetInNewFirstTile);
+    }
+    if (requiredMultipliers == tileOp.getMultiples())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "could not reduce multipliers in preceding tile");
 
-    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
+    llvm::SmallVector<int64_t> newTileShape(tileOpInputType.getShape());
+    for (auto [newShape, multiplier] :
+         llvm::zip_equal(newTileShape, requiredMultipliers)) {
+      newShape *= multiplier;
+    }
+    auto newTile = rewriter.create<tosa::TileOp>(
+        tileOp->getLoc(), tileOpInputType.clone(newTileShape),
+        tileOp->getOperand(0), requiredMultipliers);
+    auto newSlice = rewriter.create<tosa::SliceOp>(
+        sliceOp->getLoc(), sliceOp.getType(), newTile,
+        rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr());
+    rewriter.replaceOp(sliceOp, newSlice);
     return success();
   }
 };
 
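Per axis, the loop above picks the smallest tile multiplier that still covers the slice window and starts the rewritten slice inside the first kept copy of the input. A worked recomputation of that arithmetic, assuming plain C++ and a hypothetical planAxis helper that mirrors the code:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Per-axis arithmetic from the loop above: how many copies of a dim of size
// `dimSize` a slice [start, start + size) actually reads, and where the
// slice begins inside the reduced tile.
struct AxisPlan {
  int64_t multiplier; // reduced tile multiple for this axis
  int64_t newStart;   // slice start rebased into the reduced tile
};

AxisPlan planAxis(int64_t dimSize, int64_t start, int64_t size) {
  const int64_t offsetInFirstCopy = start % dimSize;
  const int64_t sizeInFirstCopy = std::min(dimSize - offsetInFirstCopy, size);
  const int64_t fullCopiesAfterFirst =
      (size - sizeInFirstCopy + dimSize - 1) / dimSize; // llvm::divideCeil
  return {fullCopiesAfterFirst + (sizeInFirstCopy != 0), offsetInFirstCopy};
}

int main() {
  // A dim of size 4 tiled, say, x5 gives 20 elements; a slice [start=6,
  // size=7) reads elements 6..12, which live in three consecutive copies,
  // and starts at offset 6 % 4 == 2 inside the reduced tile.
  const AxisPlan p = planAxis(/*dimSize=*/4, /*start=*/6, /*size=*/7);
  assert(p.multiplier == 3 && p.newStart == 2);
}
```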
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<ConcatSliceOptimization>(context);
+  results.add<TileSliceOptimization>(context);
 }
 
 struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
@@ -1321,6 +1441,21 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
   bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
   if (allOnes && getInput1().getType() == getType())
     return getInput1();
+
+  if (auto inputTile = getInput1().getDefiningOp<TileOp>()) {
+    if (!inputTile->hasOneUse()) {
+      return {};
+    }
+    llvm::SmallVector<int64_t> newMultiplies{getMultiples()};
+    for (auto [idx, multiplier] : llvm::enumerate(inputTile.getMultiples())) {
+      newMultiplies[idx] *= multiplier;
+    }
+    setMultiples(newMultiplies);
+    setOperand(inputTile->getOperand(0));
+    getOperation()->setLoc(
+        FusedLoc::get(getContext(), {inputTile->getLoc(), getLoc()}));
+    return getResult();
+  }
   return {};
 }
 
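The added fold collapses a tile whose input is another single-use tile into one tile whose multiples are the per-axis products, fusing the two locations so both origins remain traceable. A minimal sketch of why multiples compose per axis, assuming plain C++ and a hypothetical tile1d helper:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Repeat a 1-D buffer `mult` times -- the effect of a tile on one axis.
std::vector<int> tile1d(const std::vector<int> &v, int64_t mult) {
  std::vector<int> out;
  for (int64_t i = 0; i < mult; ++i)
    out.insert(out.end(), v.begin(), v.end());
  return out;
}

int main() {
  const std::vector<int> x = {1, 2, 3};
  // Tiling x by 2 and the result by 3 repeats x six times in order,
  // which is exactly a single tile of x by 2 * 3.
  assert(tile1d(tile1d(x, 2), 3) == tile1d(x, 6));
}
```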