Commit 8859aa8

Merge pull request #486 from Xilinx/jrickert.fold_concat
[TOSA] Extend folding/canonicalization for concat, tile and slice.
2 parents c733a76 + 2afdeee commit 8859aa8

3 files changed: +320 −35 lines changed
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 155 additions & 20 deletions
@@ -60,9 +60,51 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   }
 };
 
+struct SelfConcatToTile : public OpRewritePattern<tosa::ConcatOp> {
+  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::all_equal(concatOp->getUsers())) {
+      const auto concatUser = llvm::dyn_cast<tosa::ConcatOp>(
+          concatOp->getUses().begin()->getOwner());
+      if (concatUser) {
+        // Try folding the concat into its consumer before rewriting it to a
+        // tile.
+        SmallVector<Value> replacementValues;
+        auto foldResult = rewriter.tryFold(concatUser, replacementValues);
+        if (foldResult.succeeded()) {
+          if (!replacementValues.empty()) {
+            rewriter.replaceOp(concatUser, replacementValues);
+          }
+          return success();
+        }
+      }
+    }
+
+    if (!llvm::all_equal(concatOp->getOperands())) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "Requires all operands to be the same");
+    }
+    const auto concatType = dyn_cast<ShapedType>(concatOp.getType());
+    if (!concatType || !concatType.hasRank()) {
+      return rewriter.notifyMatchFailure(concatOp,
+                                         "Requires concat to be ranked");
+    }
+    SmallVector<int64_t> multiplies(concatType.getRank(), 1);
+    multiplies[concatOp.getAxis()] = concatOp->getNumOperands();
+    auto tileOp = rewriter.createOrFold<tosa::TileOp>(
+        concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0),
+        multiplies);
+    rewriter.replaceOp(concatOp, {tileOp});
+    return success();
+  }
+};
+
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<ConcatOptimization>(context);
+  results.add<SelfConcatToTile>(context);
 }
 
 struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::PowOp> {
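For context, SelfConcatToTile rewrites a concat whose operands are all the same value into a single tile along the concat axis; when the concat's only user is another concat, it first tries to fold into that consumer. A minimal sketch of the rewrite in TOSA assembly (shapes and value names are illustrative, not taken from this commit's tests):

// Before: %arg0 concatenated with itself along axis 1.
%0 = tosa.concat %arg0, %arg0 {axis = 1 : i32}
       : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x6xf32>

// After: a tile whose multiples are 1 on every axis except the concat
// axis, where the multiple equals the number of concat operands (here 2).
%0 = tosa.tile %arg0 {multiples = array<i64: 1, 2>}
       : (tensor<2x3xf32>) -> tensor<2x6xf32>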
@@ -611,42 +653,120 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
 
     llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
     llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
-
-    // Validate slice on the concatenated axis. Slicing along this
-    // axis should span only one of the inputs to the concatenate
-    // operation.
-    std::optional<Value> replaceWithSlice;
+    llvm::SmallVector<Value> requiredConcatInputs;
+    int64_t processedOriginalConcatInputSize = 0;
+    int64_t droppedConcatInputSize = 0;
     for (auto input : inputs) {
-      auto inputType = dyn_cast<RankedTensorType>(input.getType());
+      const auto inputType = dyn_cast<RankedTensorType>(input.getType());
       if (!inputType || !inputType.hasStaticShape())
         return rewriter.notifyMatchFailure(
             sliceOp, "concat input must be a static ranked tensor");
-
-      if (sliceStart[axis] >= 0 &&
-          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
-        replaceWithSlice = rewriter
-                               .create<tosa::SliceOp>(
-                                   sliceOp.getLoc(), sliceOp.getType(), input,
-                                   rewriter.getDenseI64ArrayAttr(sliceStart),
-                                   rewriter.getDenseI64ArrayAttr(sliceSize))
-                               .getResult();
-        break;
+      if (processedOriginalConcatInputSize <
+              (sliceStart[axis] + sliceSize[axis]) &&
+          (processedOriginalConcatInputSize + inputType.getDimSize(axis)) >
+              sliceStart[axis]) {
+        if (requiredConcatInputs.empty()) {
+          droppedConcatInputSize = processedOriginalConcatInputSize;
+        }
+        requiredConcatInputs.push_back(input);
       }
-      sliceStart[axis] -= inputType.getDimSize(axis);
+      processedOriginalConcatInputSize += inputType.getDimSize(axis);
+    }
+    if (requiredConcatInputs.size() == concatOp->getNumOperands()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "Could not reduce number of inputs to preceding concat");
+    }
+    if (requiredConcatInputs.size() != 1 && !concatOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "Preceding concat must have a single use"); // Do not introduce new
+                                                      // concats
+    }
+    if (requiredConcatInputs.empty()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "degenerate slice with zero sized dim in output");
     }
+    sliceStart[axis] -= droppedConcatInputSize;
+    auto newConcat = rewriter.create<tosa::ConcatOp>(
+        concatOp->getLoc(), requiredConcatInputs, axis);
+    auto newSlice = rewriter.create<tosa::SliceOp>(
+        sliceOp->getLoc(), sliceOp.getType(), newConcat,
+        rewriter.getDenseI64ArrayAttr(sliceStart),
+        rewriter.getDenseI64ArrayAttr(sliceSize));
+    rewriter.replaceOp(sliceOp, newSlice);
+    return success();
+  }
+};
+
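The reworked ConcatSliceOptimization no longer requires the slice to fall entirely inside a single concat input; it keeps every input the slice touches, drops the rest, and rebases the slice start by the total extent of the dropped leading inputs. A hedged illustration (hypothetical shapes, assuming the concat has a single use):

// Before: the slice reads elements 6..10 on axis 1, so it never touches %arg0.
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}
       : (tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x12xf32>
%1 = tosa.slice %0 {start = array<i64: 0, 6>, size = array<i64: 1, 5>}
       : (tensor<1x12xf32>) -> tensor<1x5xf32>

// After: %arg0 is dropped and the start is rebased by its extent (6 - 4 == 2).
%0 = tosa.concat %arg1, %arg2 {axis = 1 : i32}
       : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x8xf32>
%1 = tosa.slice %0 {start = array<i64: 0, 2>, size = array<i64: 1, 5>}
       : (tensor<1x8xf32>) -> tensor<1x5xf32>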
+/// This pattern adjusts the multiples of a tile followed by a slice to only
+/// tile as much data as is required by the slice.
+struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+    auto tileOp = sliceInput.getDefiningOp<tosa::TileOp>();
+    if (!tileOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be tile operation");
+    if (!tileOp->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "preceding tile must have a single use"); // Do not insert
+                                                             // additional tiles
 
-    if (!replaceWithSlice)
+    const auto tileOpInputType =
+        dyn_cast<RankedTensorType>(tileOp->getOperand(0).getType());
+    if (!tileOpInputType || !tileOpInputType.hasStaticShape())
       return rewriter.notifyMatchFailure(
-          sliceOp, "corresponding concat input not found for slice");
+          sliceOp, "input to preceding tile op must be a static ranked tensor");
+    llvm::SmallVector<int64_t> requiredMultipliers;
+    llvm::SmallVector<int64_t> newTileStarts;
+    requiredMultipliers.reserve(tileOpInputType.getRank());
+    newTileStarts.reserve(tileOpInputType.getRank());
+    for (auto [axis, sliceStart, sliceSize] :
+         llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) {
+      if (sliceSize <= 0) {
+        return rewriter.notifyMatchFailure(
+            sliceOp, "degenerate slice with zero sized dim");
+      }
+      const int64_t tileInputDimSize = tileOpInputType.getDimSize(axis);
+      const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize;
+      const int64_t sliceSizeInFirstTile =
+          std::min(tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize);
+      assert(sliceSizeInFirstTile > 0);
+      const int64_t requiredMultiplierWithoutFirstTile =
+          llvm::divideCeil(sliceSize - sliceSizeInFirstTile, tileInputDimSize);
+      const int64_t requiredMultiplier =
+          requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0);
+      assert(requiredMultiplier <= tileOp.getMultiples()[axis]);
+      requiredMultipliers.push_back(requiredMultiplier);
+      newTileStarts.push_back(sliceOffsetInNewFirstTile);
+    }
+    if (requiredMultipliers == tileOp.getMultiples())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "could not reduce multipliers in preceding tile");
 
-    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
+    llvm::SmallVector<int64_t> newTileShape(tileOpInputType.getShape());
+    for (auto [newShape, multiplier] :
+         llvm::zip_equal(newTileShape, requiredMultipliers)) {
+      newShape *= multiplier;
+    }
+    auto newTile = rewriter.create<tosa::TileOp>(
+        tileOp->getLoc(), tileOpInputType.clone(newTileShape),
+        tileOp->getOperand(0), requiredMultipliers);
+    auto newSlice = rewriter.create<tosa::SliceOp>(
+        sliceOp->getLoc(), sliceOp.getType(), newTile,
+        rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr());
+    rewriter.replaceOp(sliceOp, newSlice);
     return success();
   }
 };
 
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<ConcatSliceOptimization>(context);
+  results.add<TileSliceOptimization>(context);
 }
 
 struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
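TileSliceOptimization in action: the slice determines how many copies of the input it actually reads, the tile's multiples shrink accordingly, and the slice start is rebased into the first copy that is kept. A sketch under the same caveats as above (hypothetical shapes):

// Before: the tile makes 4 copies along axis 1, but the slice reads
// elements 5..10, which fit in 2 copies.
%0 = tosa.tile %arg0 {multiples = array<i64: 1, 4>}
       : (tensor<1x4xf32>) -> tensor<1x16xf32>
%1 = tosa.slice %0 {start = array<i64: 0, 5>, size = array<i64: 1, 6>}
       : (tensor<1x16xf32>) -> tensor<1x6xf32>

// After: the multiple drops to 2 and the start is rebased modulo the
// input extent (5 % 4 == 1).
%0 = tosa.tile %arg0 {multiples = array<i64: 1, 2>}
       : (tensor<1x4xf32>) -> tensor<1x8xf32>
%1 = tosa.slice %0 {start = array<i64: 0, 1>, size = array<i64: 1, 6>}
       : (tensor<1x8xf32>) -> tensor<1x6xf32>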
@@ -1321,6 +1441,21 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
   bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
   if (allOnes && getInput1().getType() == getType())
     return getInput1();
+
+  if (auto inputTile = getInput1().getDefiningOp<TileOp>()) {
+    if (!inputTile->hasOneUse()) {
+      return {};
+    }
+    llvm::SmallVector<int64_t> newMultiplies{getMultiples()};
+    for (auto [idx, multiplier] : llvm::enumerate(inputTile.getMultiples())) {
+      newMultiplies[idx] *= multiplier;
+    }
+    setMultiples(newMultiplies);
+    setOperand(inputTile->getOperand(0));
+    getOperation()->setLoc(
+        FusedLoc::get(getContext(), {inputTile->getLoc(), getLoc()}));
+    return getResult();
+  }
   return {};
 }
 
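The new TileOp::fold case collapses a tile of a tile (when the inner tile has no other users) by multiplying the multiples element-wise and fusing the two locations. An illustrative before/after (hypothetical shapes):

// Before: two stacked tiles; the inner tile has a single use.
%0 = tosa.tile %arg0 {multiples = array<i64: 2, 1>}
       : (tensor<2x3xf32>) -> tensor<4x3xf32>
%1 = tosa.tile %0 {multiples = array<i64: 1, 3>}
       : (tensor<4x3xf32>) -> tensor<4x9xf32>

// After folding: a single tile with element-wise multiplied multiples.
%1 = tosa.tile %arg0 {multiples = array<i64: 2, 3>}
       : (tensor<2x3xf32>) -> tensor<4x9xf32>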