
Commit 2490f7f
[mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. (#127943)

With `tensor.expand_shape` allowing a dynamic dimension to be expanded into multiple dynamic dimensions, adapt the reshape propagation through expansion to handle cases where one dynamic dimension is expanded into multiple dynamic dimensions.

Signed-off-by: MaheshRavishankar <[email protected]>
1 parent fd24805 commit 2490f7f
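
For context: `tensor.expand_shape` can expand a single dynamic dimension into multiple dynamic dimensions when the per-dimension extents are supplied through `output_shape`. A minimal sketch of the kind of IR this patch lets the propagation handle (the value names and extents here are illustrative, not taken from the patch):

  // A single ? source dim splits into ? x ?; %s0 and %s1 are SSA values
  // carrying the two dynamic output extents.
  %e = tensor.expand_shape %t [[0, 1], [2]] output_shape [%s0, %s1, 4]
      : tensor<?x4xf32> into tensor<?x?x4xf32>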

File tree

2 files changed: +177 -308 lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 78 additions & 111 deletions
@@ -25,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <utility>

@@ -590,44 +591,45 @@ class ExpansionInfo {
   // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
+                        ArrayRef<OpFoldResult> expandedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
     return reassociation[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

 private:
   /// Reassociation from the dimensions in the original operation to the
   /// dimension of the expanded operation.
   SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
   /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<OpFoldResult> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace

 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
+                                     ArrayRef<OpFoldResult> expandedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

-  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
-  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(linalgOp);
+  originalLoopExtent = llvm::map_to_vector(
+      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+      [](Range r) { return r.size; });

   reassociation.clear();
   expandedShapeMap.clear();
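
Note: with the switch from `getStaticLoopRanges` to `createLoopRanges`, loop extents become `OpFoldResult`s, i.e. either constant attributes or SSA values, instead of `int64_t`s with a dynamic sentinel. For a dynamic dimension the extent is typically materialized from an operand, roughly as in this illustrative sketch (names assumed, not from the patch):

  %c0 = arith.constant 0 : index
  // The dynamic loop extent is an SSA value instead of ShapedType::kDynamic.
  %d0 = tensor.dim %arg0, %c0 : tensor<?x4xf32>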
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
     unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
     numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
+    ArrayRef<OpFoldResult> shape =
         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
     expandedShapeMap[pos].assign(shape.begin(), shape.end());
   }
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }

-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
-                                          const ExpansionInfo &expansionInfo,
-                                          PatternRewriter &rewriter) {
-  if (!linalgOp.hasIndexSemantics())
-    return success();
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    for (int64_t shape : expandedShape.drop_front()) {
-      if (ShapedType::isDynamic(shape)) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot expand due to index semantics and dynamic dims");
-      }
-    }
-  }
-  return success();
-}
-
 /// Return the indexing map to use in the expanded op for a given the
 /// `indexingMap` of the original operation.
 static AffineMap
@@ -706,18 +681,23 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                        builder.getContext());
 }

-/// Return the type of the operand/result to use in the expanded op given the
-/// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
+/// Return the shape and type of the operand/result to use in the expanded op
+/// given the type in the original op.
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+                        const ExpansionInfo &expansionInfo) {
+  SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    ArrayRef<OpFoldResult> dimExpansion =
+        expansionInfo.getExpandedShapeOfDim(dim);
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
+  SmallVector<int64_t> expandedStaticShape;
+  std::tie(expandedStaticShape, std::ignore) =
+      decomposeMixedValues(expandedShape);
+  return {expandedShape, RankedTensorType::get(expandedStaticShape,
+                                               originalType.getElementType())};
 }

 /// Returns the reassociation maps to use in the `tensor.expand_shape`
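
Note: the helper now returns the mixed `OpFoldResult` shape together with the tensor type derived from its static component, so callers can pass the dynamic extents to the `tensor.expand_shape` builder. Roughly, the reshape emitted for a non-fused operand then looks like this illustrative sketch (names assumed):

  // Mixed sizes [%d0, %d1, 8] decompose into the static type tensor<?x?x8xf32>
  // while the dynamic extents travel in output_shape.
  %opnd = tensor.expand_shape %arg [[0, 1], [2]] output_shape [%d0, %d1, 8]
      : tensor<?x8xf32> into tensor<?x?x8xf32>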
@@ -765,49 +745,28 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
     // Linearize the expanded indices of the original index dimension.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(indexOp);
-    ArrayRef<int64_t> expandedDimsShape =
+    ArrayRef<OpFoldResult> expandedDimsShape =
         expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
     SmallVector<Value> expandedIndices;
     expandedIndices.reserve(expandedDims.size() - 1);
     llvm::transform(
         expandedDims.drop_front(), std::back_inserter(expandedIndices),
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
-    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
-    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
-      assert(!ShapedType::isDynamic(std::get<0>(it)));
-      AffineExpr idx, acc;
+    OpFoldResult newIndex =
+        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
+    for (auto [expandedShape, expandedIndex] :
+         llvm::zip(expandedDimsShape, expandedIndices)) {
+      AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
-      newIndex = rewriter.create<affine::AffineApplyOp>(
-          indexOp.getLoc(), idx + acc * std::get<0>(it),
-          ValueRange{std::get<1>(it), newIndex});
-    }
-    rewriter.replaceOp(indexOp, newIndex);
-  }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
-                            const ExpansionInfo &expansionInfo,
-                            PatternRewriter &rewriter) {
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    bool foundDynamic = false;
-    for (int64_t shape : expandedShape) {
-      if (!ShapedType::isDynamic(shape))
-        continue;
-      if (foundDynamic) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot infer expanded shape with multiple dynamic "
-                      "dims in the same reassociation group");
-      }
-      foundDynamic = true;
+      bindSymbols(rewriter.getContext(), shape);
+      newIndex = affine::makeComposedFoldedAffineApply(
+          rewriter, indexOp.getLoc(), idx + acc * shape,
+          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
     }
+    Value newIndexVal =
+        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+    rewriter.replaceOp(indexOp, newIndexVal);
   }
-  return success();
 }

 // Create an expanded transpose op.
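
Note: replacing the `affine.apply` builder with `makeComposedFoldedAffineApply` is what lifts the old static-shape restriction on ops with index semantics. A `linalg.index` of an original dimension that expanded into two dimensions is rebuilt roughly as in this sketch, where `%sz1` stands for the inner expanded extent and may itself be dynamic (illustrative; this sits inside the op's body in practice):

  %i0 = linalg.index 0 : index
  %i1 = linalg.index 1 : index
  // Linearization idx + acc * shape; folds to a constant when the extent is static.
  %orig = affine.apply affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>(%i0, %i1)[%sz1]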
@@ -910,31 +869,34 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
          "preconditions for fuse operation failed");

   Location loc = linalgOp.getLoc();
-  // Check if reshape is expanding or collapsing.
-  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
-  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
-  bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
+  SmallVector<OpFoldResult> expandedShape, collapsedShape;
+  SmallVector<AffineMap, 4> reassociationIndices;
+  Value src;
+  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+    // Try to move the dynamic dimensions in output shape before the `linalgOp`
+    // to maintain SSA validity
+    if (failed(moveValueDefinitions(
+            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
+      return std::nullopt;
+
+    expandedShape = expandingReshapeOp.getMixedOutputShape();
+    reassociationIndices = expandingReshapeOp.getReassociationMaps();
+    src = expandingReshapeOp.getSrc();
+  } else {
+    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    if (!collapsingReshapeOp)
+      return std::nullopt;
+
+    expandedShape = tensor::getMixedSizes(
+        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+    src = collapsingReshapeOp.getSrc();
+  }

   ExpansionInfo expansionInfo;
-  if (failed(expansionInfo.compute(
-          linalgOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
-    return std::nullopt;
-
-  // TODO: With the support of multiple dynamic dims expansion in
-  // tensor.expand_shape op, this case can be handled.
-  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
-    return std::nullopt;
-
-  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
+                                   reassociationIndices, expandedShape,
+                                   rewriter)))
     return std::nullopt;

   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
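
Note: `moveValueDefinitions` (the new RegionUtils.h dependency) hoists the SSA values feeding `output_shape` above the `linalg` op, because fusion recreates the expansion on the producer's operands, ahead of where those sizes were originally defined. An illustrative before-state that needs the hoist (ops and names assumed, not from the patch):

  %r = linalg.exp ins(%in : tensor<?x32xf32>)
         outs(%init : tensor<?x32xf32>) -> tensor<?x32xf32>
  // %s0 is defined below the producer but is needed when the expansion is
  // propagated onto the producer's operands.
  %s0 = arith.divui %n, %m : index
  %e = tensor.expand_shape %r [[0, 1], [2]] output_shape [%s0, %m, 32]
      : tensor<?x32xf32> into tensor<?x?x32xf32>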
@@ -950,15 +912,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
   for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
-      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
-                                               : collapsingReshapeOp.getSrc());
+      expandedOpOperands.push_back(src);
       continue;
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
+      SmallVector<OpFoldResult> expandedOperandShape;
+      RankedTensorType expandedOperandType;
+      std::tie(expandedOperandShape, expandedOperandType) =
+          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +935,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
+            loc, expandedOperandType, opOperand->get(), reassociation,
+            expandedOperandShape));
         continue;
       }
     }
@@ -983,8 +947,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
-    RankedTensorType expandedOutputType =
-        getExpandedType(opOperandType, indexingMap, expansionInfo);
+    SmallVector<OpFoldResult> expandedOutputShape;
+    RankedTensorType expandedOutputType;
+    std::tie(expandedOutputShape, expandedOutputType) =
+        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand.get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
@@ -997,7 +963,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          loc, expandedOutputType, opOperand.get(), reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation,
+          expandedOutputShape));
     } else {
       outputs.push_back(opOperand.get());
     }
