Commit 09fa404
[mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes.
With `tensor.expand_shape` now allowing a dynamic dimension to be expanded into multiple dynamic dimensions, adapt the reshape propagation through expansion to handle cases where one dynamic dimension is expanded into multiple dynamic dimensions.

Signed-off-by: MaheshRavishankar <[email protected]>
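For illustration, a minimal sketch (hypothetical values, not part of the patch) of the kind of reshape this change learns to propagate: one dynamic source dimension split into two dynamic result dimensions. The five-argument ExpandShapeOp builder used below is the same one the patch calls; `b`, `loc`, `src`, `d0`, and `d1` are assumed to be in scope.

  // Expand tensor<?x32xf32> into tensor<?x?x32xf32>, splitting dim 0 into
  // two dynamic dimensions with extents d0 and d1 (index Values). The
  // explicit output shape carries the dynamic extents as OpFoldResults,
  // which is what lets the propagation reason about reassociation groups
  // that contain more than one dynamic dimension.
  auto resultType = RankedTensorType::get(
      {ShapedType::kDynamic, ShapedType::kDynamic, 32}, b.getF32Type());
  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}};
  SmallVector<OpFoldResult> outputShape = {d0, d1, b.getIndexAttr(32)};
  Value expanded = b.create<tensor::ExpandShapeOp>(
      loc, resultType, src, reassociation, outputShape);

Before this change, the equivalent expansion was rejected because a reassociation group with more than one dynamic extent could not be inferred from static shapes alone.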
1 parent 7c24041 · commit 09fa404

File tree: 2 files changed, +131 −295 lines

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Lines changed: 70 additions & 107 deletions
@@ -595,44 +595,45 @@ class ExpansionInfo {
   // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
+                        ArrayRef<OpFoldResult> expandedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
     return reassociation[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

 private:
   /// Reassociation from the dimensions in the original operation to the
   /// dimension of the expanded operation.
   SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
   /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<OpFoldResult> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace

 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
+                                     ArrayRef<OpFoldResult> expandedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

-  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
-  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(linalgOp);
+  originalLoopExtent = llvm::map_to_vector(
+      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+      [](Range r) { return r.size; });

   reassociation.clear();
   expandedShapeMap.clear();
@@ -644,7 +645,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
     unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
     numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
+    ArrayRef<OpFoldResult> shape =
         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
     expandedShapeMap[pos].assign(shape.begin(), shape.end());
   }
@@ -665,33 +666,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }

-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
-                                          const ExpansionInfo &expansionInfo,
-                                          PatternRewriter &rewriter) {
-  if (!linalgOp.hasIndexSemantics())
-    return success();
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    for (int64_t shape : expandedShape.drop_front()) {
-      if (ShapedType::isDynamic(shape)) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot expand due to index semantics and dynamic dims");
-      }
-    }
-  }
-  return success();
-}
-
 /// Return the indexing map to use in the expanded op for a given the
 /// `indexingMap` of the original operation.
 static AffineMap
@@ -713,16 +687,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,

 /// Return the type of the operand/result to use in the expanded op given the
 /// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+                        const ExpansionInfo &expansionInfo) {
+  SmallVector<int64_t> expandedStaticShape;
+  SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    ArrayRef<OpFoldResult> dimExpansion =
+        expansionInfo.getExpandedShapeOfDim(dim);
+    llvm::append_range(expandedStaticShape,
+                       llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+                         std::optional<int64_t> staticShape =
+                             getConstantIntValue(ofr);
+                         if (staticShape) {
+                           return staticShape.value();
+                         }
+                         return ShapedType::kDynamic;
+                       }));
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
+  return {expandedShape, RankedTensorType::get(expandedStaticShape,
+                                               originalType.getElementType())};
 }

 /// Returns the reassociation maps to use in the `tensor.expand_shape`
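The hunk above computes both shapes in one pass: the mixed OpFoldResult shape later feeds the expand_shape output-shape operands, while every size that getConstantIntValue can fold becomes a static extent and everything else becomes ShapedType::kDynamic. A standalone sketch of that per-size mapping (hypothetical helper, not in the patch):

  // Derive the static extents for a RankedTensorType from a mixed shape,
  // mirroring the lambda in the hunk above: constants (e.g. an IntegerAttr
  // holding 32) become fixed sizes, SSA values become dynamic.
  static SmallVector<int64_t> toStaticShape(ArrayRef<OpFoldResult> mixed) {
    SmallVector<int64_t> staticShape;
    for (OpFoldResult ofr : mixed) {
      if (std::optional<int64_t> cst = getConstantIntValue(ofr))
        staticShape.push_back(*cst);
      else
        staticShape.push_back(ShapedType::kDynamic);
    }
    return staticShape;
  }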
@@ -770,49 +756,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
     // Linearize the expanded indices of the original index dimension.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(indexOp);
-    ArrayRef<int64_t> expandedDimsShape =
+    ArrayRef<OpFoldResult> expandedDimsShape =
         expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
     SmallVector<Value> expandedIndices;
     expandedIndices.reserve(expandedDims.size() - 1);
     llvm::transform(
         expandedDims.drop_front(), std::back_inserter(expandedIndices),
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
-    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+    OpFoldResult newIndex =
+        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
-      assert(!ShapedType::isDynamic(std::get<0>(it)));
-      AffineExpr idx, acc;
+      AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
-      newIndex = rewriter.create<affine::AffineApplyOp>(
-          indexOp.getLoc(), idx + acc * std::get<0>(it),
-          ValueRange{std::get<1>(it), newIndex});
-    }
-    rewriter.replaceOp(indexOp, newIndex);
-  }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
-                            const ExpansionInfo &expansionInfo,
-                            PatternRewriter &rewriter) {
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    bool foundDynamic = false;
-    for (int64_t shape : expandedShape) {
-      if (!ShapedType::isDynamic(shape))
-        continue;
-      if (foundDynamic) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot infer expanded shape with multiple dynamic "
-                      "dims in the same reassociation group");
-      }
-      foundDynamic = true;
+      bindSymbols(rewriter.getContext(), shape);
+      newIndex = affine::makeComposedFoldedAffineApply(
+          rewriter, indexOp.getLoc(), idx + acc * shape,
+          ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
     }
+    Value newIndexVal =
+        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+    rewriter.replaceOp(indexOp, newIndexVal);
   }
-  return success();
 }

 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
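With the extent bound as an affine symbol instead of folded in as a constant, the index linearization above no longer needs static inner extents: for an index dimension split into extents (s0, s1, s2) with expanded indices (i0, i1, i2), the replacement computes ((i0 * s1 + i1) * s2 + i2), and makeComposedFoldedAffineApply folds each step whenever the extent happens to be constant. A standalone sketch of that recurrence (hypothetical helper, same APIs as the hunk above):

  // Fold expanded indices back into the original linear index using the
  // recurrence newIndex = i_k + newIndex * s_k built by the loop above.
  static OpFoldResult linearizeIndex(OpBuilder &b, Location loc,
                                     ArrayRef<OpFoldResult> indices,
                                     ArrayRef<OpFoldResult> extents) {
    OpFoldResult linear = indices.front();
    for (auto [index, extent] :
         llvm::zip(indices.drop_front(), extents.drop_front())) {
      AffineExpr idx, acc, shape;
      bindDims(b.getContext(), idx, acc);
      bindSymbols(b.getContext(), shape);
      // Dynamic extents stay symbolic; constant extents fold away.
      linear = affine::makeComposedFoldedAffineApply(
          b, loc, idx + acc * shape, {index, linear, extent});
    }
    return linear;
  }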
@@ -826,31 +790,25 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
          "preconditions for fuse operation failed");

   Location loc = linalgOp.getLoc();
-  // Check if reshape is expanding or collapsing.
-  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
-  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
-  bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
+  SmallVector<OpFoldResult> expandedShape, collapsedShape;
+  SmallVector<AffineMap, 4> reassociationIndices;
+  Value src;
+  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+    expandedShape = expandingReshapeOp.getMixedOutputShape();
+    reassociationIndices = expandingReshapeOp.getReassociationMaps();
+    src = expandingReshapeOp.getSrc();
+  } else {
+    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    expandedShape = tensor::getMixedSizes(
+        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+    src = collapsingReshapeOp.getSrc();
+  }

   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
-          linalgOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
-    return std::nullopt;
-
-  // TODO: With the support of multiple dynamic dims expansion in
-  // tensor.expand_shape op, this case can be handled.
-  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
-    return std::nullopt;
-
-  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+          linalgOp, fusableOpOperand, reassociationIndices,
+          expandedShape, rewriter)))
     return std::nullopt;

   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -866,15 +824,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
   for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
-      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
-                                               : collapsingReshapeOp.getSrc());
+      expandedOpOperands.push_back(src);
       continue;
     }
     if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
+      SmallVector<OpFoldResult> expandedOperandShape;
+      RankedTensorType expandedOperandType;
+      std::tie(expandedOperandShape, expandedOperandType) =
+          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
@@ -888,7 +847,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
+            loc, expandedOperandType, opOperand->get(), reassociation,
+            expandedOperandShape));
         continue;
       }
     }
@@ -899,8 +859,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
-    RankedTensorType expandedOutputType =
-        getExpandedType(opOperandType, indexingMap, expansionInfo);
+    SmallVector<OpFoldResult> expandedOutputShape;
+    RankedTensorType expandedOutputType;
+    std::tie(expandedOutputShape, expandedOutputType) =
+        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand.get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
@@ -913,7 +875,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
               /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          loc, expandedOutputType, opOperand.get(), reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation,
+          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }

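For context on how these rewrites are exercised end to end: the expansion patterns populated by this file are typically gathered through Linalg's reshape-propagation hook and run greedily. A minimal sketch, assuming the populateFoldReshapeOpsByExpansionPatterns entry point declared in Linalg's Transforms.h, run inside a pass whose runOnOperation provides funcOp:

  // Collect the reshape-by-expansion propagation patterns and apply them.
  // The control function here opts in every candidate operand; real
  // pipelines often restrict it to bound intermediate tensor sizes.
  RewritePatternSet patterns(funcOp.getContext());
  linalg::populateFoldReshapeOpsByExpansionPatterns(
      patterns, /*controlFoldingReshapes=*/[](OpOperand *fusedOperand) {
        return true;
      });
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    signalPassFailure();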