[mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. #127943
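As a rough illustration of the newly supported case (a hand-written sketch, not taken from this PR's tests; the function names and the generic body are made up), the pattern can now propagate a `tensor.expand_shape` whose reassociation group contains more than one dynamic size across an elementwise `linalg.generic`:

```mlir
#map = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>

// Before: an elementwise generic on a collapsed dynamic tensor, followed by an
// expansion where both sizes of the reassociation group are dynamic.
func.func @before(%arg0: tensor<?xf32>, %init: tensor<?xf32>,
                  %d0: index, %d1: index) -> tensor<?x?xf32> {
  %0 = linalg.generic
         {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
         ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %in : f32
    linalg.yield %1 : f32
  } -> tensor<?xf32>
  %expanded = tensor.expand_shape %0 [[0, 1]] output_shape [%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>
  return %expanded : tensor<?x?xf32>
}

// After propagation: the reshape moves onto the operands and the generic
// iterates directly over the expanded (still fully dynamic) space.
func.func @after(%arg0: tensor<?xf32>, %init: tensor<?xf32>,
                 %d0: index, %d1: index) -> tensor<?x?xf32> {
  %lhs = tensor.expand_shape %arg0 [[0, 1]] output_shape [%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>
  %out = tensor.expand_shape %init [[0, 1]] output_shape [%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>
  %0 = linalg.generic
         {indexing_maps = [#map1, #map1],
          iterator_types = ["parallel", "parallel"]}
         ins(%lhs : tensor<?x?xf32>) outs(%out : tensor<?x?xf32>) {
  ^bb0(%in: f32, %out_elem: f32):
    %1 = arith.addf %in, %in : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```

The expanded sizes are now tracked as `OpFoldResult`s rather than static `int64_t` extents, so the dynamic values (`%d0`, `%d1` above) can be reused directly when building the expanded operands and results.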

Merged
189 changes: 78 additions & 111 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include <optional>
#include <utility>

@@ -590,44 +591,45 @@ class ExpansionInfo {
// the expanded op.
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
ArrayRef<int64_t> collapsedShape,
ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
ReassociationIndicesRef getExpandedDims(unsigned i) const {
return reassociation[i];
}
ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
return expandedShapeMap[i];
}
ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
/// Reassociation from the dimensions in the original operation to the
/// dimension of the expanded operation.
SmallVector<ReassociationIndices> reassociation;
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
SmallVector<SmallVector<int64_t>> expandedShapeMap;
SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
/// Extent of the loop in the original operation.
SmallVector<int64_t> originalLoopExtent;
SmallVector<OpFoldResult> originalLoopExtent;
unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
ArrayRef<int64_t> collapsedShape,
ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(linalgOp);
originalLoopExtent = llvm::map_to_vector(
linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
[](Range r) { return r.size; });

reassociation.clear();
expandedShapeMap.clear();
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
numExpandedDims[pos] = foldedDims.getNumResults();
ArrayRef<int64_t> shape =
ArrayRef<OpFoldResult> shape =
expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
expandedShapeMap[pos].assign(shape.begin(), shape.end());
}
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation need to be replaced with linearizations of indices in the expanded
/// op. That requires the shape of the expanded dimensions to be static (at
/// least all but the most significant). For now check that these are all
/// statically sized. Note that this could be extended to handle dynamic case,
/// but the implementation below uses `affine.apply` which seems to have issues
/// when the shapes are not static.
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
if (!linalgOp.hasIndexSemantics())
return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
if (expandedShape.size() == 1)
continue;
for (int64_t shape : expandedShape.drop_front()) {
if (ShapedType::isDynamic(shape)) {
return rewriter.notifyMatchFailure(
linalgOp, "cannot expand due to index semantics and dynamic dims");
}
}
}
return success();
}

/// Return the indexing map to use in the expanded op given the
/// `indexingMap` of the original operation.
static AffineMap
@@ -706,18 +681,23 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
SmallVector<int64_t> expandedShape;
/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
SmallVector<OpFoldResult> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
ArrayRef<OpFoldResult> dimExpansion =
expansionInfo.getExpandedShapeOfDim(dim);
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
}
return RankedTensorType::get(expandedShape, originalType.getElementType());
SmallVector<int64_t> expandedStaticShape;
std::tie(expandedStaticShape, std::ignore) =
decomposeMixedValues(expandedShape);
return {expandedShape, RankedTensorType::get(expandedStaticShape,
originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -765,49 +745,28 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
// Linearize the expanded indices of the original index dimension.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(indexOp);
ArrayRef<int64_t> expandedDimsShape =
ArrayRef<OpFoldResult> expandedDimsShape =
expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
SmallVector<Value> expandedIndices;
expandedIndices.reserve(expandedDims.size() - 1);
llvm::transform(
expandedDims.drop_front(), std::back_inserter(expandedIndices),
[&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
assert(!ShapedType::isDynamic(std::get<0>(it)));
AffineExpr idx, acc;
OpFoldResult newIndex =
rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
for (auto [expandedShape, expandedIndex] :
llvm::zip(expandedDimsShape, expandedIndices)) {
AffineExpr idx, acc, shape;
bindDims(rewriter.getContext(), idx, acc);
newIndex = rewriter.create<affine::AffineApplyOp>(
indexOp.getLoc(), idx + acc * std::get<0>(it),
ValueRange{std::get<1>(it), newIndex});
}
rewriter.replaceOp(indexOp, newIndex);
}
}

/// Checks if a single dynamic dimension expanded into multiple dynamic
/// dimensions.
static LogicalResult
validateDynamicDimExpansion(LinalgOp linalgOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
if (expandedShape.size() == 1)
continue;
bool foundDynamic = false;
for (int64_t shape : expandedShape) {
if (!ShapedType::isDynamic(shape))
continue;
if (foundDynamic) {
return rewriter.notifyMatchFailure(
linalgOp, "cannot infer expanded shape with multiple dynamic "
"dims in the same reassociation group");
}
foundDynamic = true;
bindSymbols(rewriter.getContext(), shape);
newIndex = affine::makeComposedFoldedAffineApply(
rewriter, indexOp.getLoc(), idx + acc * shape,
ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
}
Value newIndexVal =
getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
rewriter.replaceOp(indexOp, newIndexVal);
}
return success();
}

// Create an expanded transpose op.
@@ -910,31 +869,34 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
"preconditions for fuse operation failed");

Location loc = linalgOp.getLoc();
// Check if reshape is expanding or collapsing.
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
bool isExpanding = (expandingReshapeOp != nullptr);
RankedTensorType expandedType = isExpanding
? expandingReshapeOp.getResultType()
: collapsingReshapeOp.getSrcType();
RankedTensorType collapsedType = isExpanding
? expandingReshapeOp.getSrcType()
: collapsingReshapeOp.getResultType();
SmallVector<OpFoldResult> expandedShape, collapsedShape;
SmallVector<AffineMap, 4> reassociationIndices;
Value src;
if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
// Try to move the dynamic dimensions in output shape before the `linalgOp`
// to maintain SSA validity
if (failed(moveValueDefinitions(
rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
return std::nullopt;
Review thread on lines +878 to +880:

@IanWood1 (Contributor), Mar 12, 2025:
This might be problematic because after moveValueDefinitions mutates the IR, the rewrite pattern is no longer allowed to return failure().

Contributor:
The usage of reshapeLikeShapesAreCompatible has a similar problem below (prior to this PR), so maybe this isn't that big of a deal.

Contributor (Author):
Yeah, that happens in many places in the IR, and I don't have a good solution for it. In general it just changes the position of some tensor.dim operations, which shouldn't matter too much.
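For context, a minimal hypothetical example of the situation discussed here (not taken from the PR's tests): the values feeding `output_shape` can be defined after the linalg op, so `moveValueDefinitions` hoists them above it before the expanded operands are created.

```mlir
#map = affine_map<(d0) -> (d0)>

// The dims feeding output_shape are defined *after* the generic whose result
// is expanded, so they must be hoisted above it for the fused op to use them.
func.func @dims_defined_late(%arg0: tensor<?xf32>, %init: tensor<?xf32>,
                             %ref: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic
         {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
         ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %ref, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %ref, %c1 : tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
```

The hoisted ops are side-effect-free index computations (constants and `tensor.dim`), which is why the reordering is considered benign even though, as noted above, it mutates the IR before the pattern is guaranteed to succeed.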


expandedShape = expandingReshapeOp.getMixedOutputShape();
reassociationIndices = expandingReshapeOp.getReassociationMaps();
src = expandingReshapeOp.getSrc();
} else {
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
if (!collapsingReshapeOp)
return std::nullopt;

expandedShape = tensor::getMixedSizes(
rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
reassociationIndices = collapsingReshapeOp.getReassociationMaps();
src = collapsingReshapeOp.getSrc();
}

ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(
linalgOp, fusableOpOperand,
isExpanding ? expandingReshapeOp.getReassociationMaps()
: collapsingReshapeOp.getReassociationMaps(),
expandedType.getShape(), collapsedType.getShape(), rewriter)))
return std::nullopt;

// TODO: With the support of multiple dynamic dims expansion in
// tensor.expand_shape op, this case can be handled.
if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
return std::nullopt;

if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
reassociationIndices, expandedShape,
rewriter)))
return std::nullopt;

SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -950,15 +912,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
: collapsingReshapeOp.getSrc());
expandedOpOperands.push_back(src);
continue;
}
if (auto opOperandType =
dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
RankedTensorType expandedOperandType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
SmallVector<OpFoldResult> expandedOperandShape;
RankedTensorType expandedOperandType;
std::tie(expandedOperandShape, expandedOperandType) =
getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +935,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
loc, expandedOperandType, opOperand->get(), reassociation));
loc, expandedOperandType, opOperand->get(), reassociation,
expandedOperandShape));
continue;
}
}
Expand All @@ -983,8 +947,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
RankedTensorType expandedOutputType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
SmallVector<OpFoldResult> expandedOutputShape;
RankedTensorType expandedOutputType;
std::tie(expandedOutputShape, expandedOutputType) =
getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand.get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
@@ -997,7 +963,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
loc, expandedOutputType, opOperand.get(), reassociation));
loc, expandedOutputType, opOperand.get(), reassociation,
expandedOutputShape));
} else {
outputs.push_back(opOperand.get());
}