[mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. #127943
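Summary of the change, as reflected in the diff below: `ExpansionInfo` now tracks expanded and original loop extents as `OpFoldResult` instead of `int64_t`, the expanded sizes are taken from `tensor.expand_shape`'s `output_shape` (or from the mixed sizes of `tensor.collapse_shape`'s source), the static-shape checks `isLinalgOpExpandable` and `validateDynamicDimExpansion` are removed, and `linalg.index` values are linearized with `makeComposedFoldedAffineApply`, so the expanded extents may be dynamic.

As a rough, hypothetical sketch of the kind of IR this enables (not taken from the PR's tests; names, shapes, and the payload are illustrative), an expand_shape whose reassociation group carries more than one dynamic extent can now be propagated across a `linalg.generic`:

```mlir
// Hypothetical input: group [0, 1] of the expand_shape has two dynamic
// extents (%sz0, %sz1), which previously blocked this fusion.
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @before(%arg0: tensor<?x?xf32>, %init: tensor<?x?xf32>,
                  %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map],
                       iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %in : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  %expanded = tensor.expand_shape %0 [[0, 1], [2]]
      output_shape [%sz0, %sz1, %sz2]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %expanded : tensor<?x?x?xf32>
}
```

After the rewrite the generic itself is expanded and the reshape moves onto its operands, roughly as follows (again a sketch; exact SSA names and op order may differ):

```mlir
// Hypothetical output: the operands are expanded with the same output_shape
// values and the generic now iterates over three dimensions.
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @after(%arg0: tensor<?x?xf32>, %init: tensor<?x?xf32>,
                 %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?xf32> {
  %in_e = tensor.expand_shape %arg0 [[0, 1], [2]]
      output_shape [%sz0, %sz1, %sz2]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %init_e = tensor.expand_shape %init [[0, 1], [2]]
      output_shape [%sz0, %sz1, %sz2]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %0 = linalg.generic {indexing_maps = [#map3, #map3],
                       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%in_e : tensor<?x?x?xf32>) outs(%init_e : tensor<?x?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %in : f32
    linalg.yield %1 : f32
  } -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
```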
@@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include <optional>
#include <utility>
@@ -590,44 +591,45 @@ class ExpansionInfo {
// the expanded op.
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
ArrayRef<int64_t> collapsedShape,
ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
ReassociationIndicesRef getExpandedDims(unsigned i) const {
return reassociation[i];
}
ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
return expandedShapeMap[i];
}
ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
/// Reassociation from the dimensions in the original operation to the
/// dimension of the expanded operation.
SmallVector<ReassociationIndices> reassociation;
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
SmallVector<SmallVector<int64_t>> expandedShapeMap;
SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
/// Extent of the loop in the original operation.
SmallVector<int64_t> originalLoopExtent;
SmallVector<OpFoldResult> originalLoopExtent;
unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
ArrayRef<int64_t> collapsedShape,
ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(linalgOp);
originalLoopExtent = llvm::map_to_vector(
linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
[](Range r) { return r.size; });

reassociation.clear();
expandedShapeMap.clear();
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
numExpandedDims[pos] = foldedDims.getNumResults();
ArrayRef<int64_t> shape =
ArrayRef<OpFoldResult> shape =
expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
expandedShapeMap[pos].assign(shape.begin(), shape.end());
}
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation need to be replaced with linearizations of indices in the expanded
/// op. That requires the shape of the expanded dimensions to be static (at
/// least all but the most significant). For now check that these are all
/// statically sized. Note that this could be extended to handle dynamic case,
/// but the implementation below uses `affine.apply` which seems to have issues
/// when the shapes are not static.
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
if (!linalgOp.hasIndexSemantics())
return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
if (expandedShape.size() == 1)
continue;
for (int64_t shape : expandedShape.drop_front()) {
if (ShapedType::isDynamic(shape)) {
return rewriter.notifyMatchFailure(
linalgOp, "cannot expand due to index semantics and dynamic dims");
}
}
}
return success();
}

/// Return the indexing map to use in the expanded op for a given the
/// `indexingMap` of the original operation.
static AffineMap
@@ -706,18 +681,23 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
SmallVector<int64_t> expandedShape;
/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
SmallVector<OpFoldResult> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
ArrayRef<OpFoldResult> dimExpansion =
expansionInfo.getExpandedShapeOfDim(dim);
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
}
return RankedTensorType::get(expandedShape, originalType.getElementType());
SmallVector<int64_t> expandedStaticShape;
std::tie(expandedStaticShape, std::ignore) =
decomposeMixedValues(expandedShape);
return {expandedShape, RankedTensorType::get(expandedStaticShape,
originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -765,49 +745,28 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
// Linearize the expanded indices of the original index dimension.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(indexOp);
ArrayRef<int64_t> expandedDimsShape =
ArrayRef<OpFoldResult> expandedDimsShape =
expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
SmallVector<Value> expandedIndices;
expandedIndices.reserve(expandedDims.size() - 1);
llvm::transform(
expandedDims.drop_front(), std::back_inserter(expandedIndices),
[&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
assert(!ShapedType::isDynamic(std::get<0>(it)));
AffineExpr idx, acc;
OpFoldResult newIndex =
rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
for (auto [expandedShape, expandedIndex] :
llvm::zip(expandedDimsShape, expandedIndices)) {
AffineExpr idx, acc, shape;
bindDims(rewriter.getContext(), idx, acc);
newIndex = rewriter.create<affine::AffineApplyOp>(
indexOp.getLoc(), idx + acc * std::get<0>(it),
ValueRange{std::get<1>(it), newIndex});
}
rewriter.replaceOp(indexOp, newIndex);
}
}

/// Checks if a single dynamic dimension expanded into multiple dynamic
/// dimensions.
static LogicalResult
validateDynamicDimExpansion(LinalgOp linalgOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
if (expandedShape.size() == 1)
continue;
bool foundDynamic = false;
for (int64_t shape : expandedShape) {
if (!ShapedType::isDynamic(shape))
continue;
if (foundDynamic) {
return rewriter.notifyMatchFailure(
linalgOp, "cannot infer expanded shape with multiple dynamic "
"dims in the same reassociation group");
}
foundDynamic = true;
bindSymbols(rewriter.getContext(), shape);
newIndex = affine::makeComposedFoldedAffineApply(
rewriter, indexOp.getLoc(), idx + acc * shape,
ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
}
Value newIndexVal =
getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
rewriter.replaceOp(indexOp, newIndexVal);
}
return success();
}

// Create an expanded transpose op.
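For reference, the linearization that the rewritten `updateExpandedGenericOpRegion` code above performs on `linalg.index` values can be sketched as standalone IR. This is a minimal, hypothetical example (`%i0`/`%i1`/`%i2` stand in for the expanded index results and `%s1`/`%s2` for the inner expanded extents, which may now be dynamic); the actual code emits it through `makeComposedFoldedAffineApply`, so constant extents fold away:

```mlir
// Rebuilds the original index d as (i0 * s1 + i1) * s2 + i2, one apply per
// expanded dimension, with the extent passed as an affine symbol.
func.func @linearize(%i0: index, %i1: index, %i2: index,
                     %s1: index, %s2: index) -> index {
  %acc = affine.apply affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>(%i1, %i0)[%s1]
  %d = affine.apply affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>(%i2, %acc)[%s2]
  return %d : index
}
```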
@@ -910,31 +869,34 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
"preconditions for fuse operation failed");

Location loc = linalgOp.getLoc();
// Check if reshape is expanding or collapsing.
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
bool isExpanding = (expandingReshapeOp != nullptr);
RankedTensorType expandedType = isExpanding
? expandingReshapeOp.getResultType()
: collapsingReshapeOp.getSrcType();
RankedTensorType collapsedType = isExpanding
? expandingReshapeOp.getSrcType()
: collapsingReshapeOp.getResultType();
SmallVector<OpFoldResult> expandedShape, collapsedShape;
SmallVector<AffineMap, 4> reassociationIndices;
Value src;
if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
// Try to move the dynamic dimensions in output shape before the `linalgOp`
// to maintain SSA validity
if (failed(moveValueDefinitions(
rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
return std::nullopt;
Comment on lines +878 to +880

This might be problematic because after …

The usage of …

Yeah, many places that happens in the IR. I don't have a good solution for this though. In general it's just changing the position of some …
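For context on the thread above, a minimal hypothetical example (not from the PR) of the situation `moveValueDefinitions` handles: the values feeding `output_shape` may be defined textually after the linalg op, while the rewrite creates new `tensor.expand_shape` ops in front of that op which need them, so their defining ops are hoisted first when they do not depend on the linalg op:

```mlir
// %sz0 is defined after %fill but does not depend on it, so it can be hoisted
// above %fill before new expand_shape ops that use it are created there.
func.func @ssa_order(%init: tensor<?xf32>, %cst: f32, %a: index, %b: index)
    -> tensor<?x4xf32> {
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
  %sz0 = arith.muli %a, %b : index
  %e = tensor.expand_shape %fill [[0, 1]] output_shape [%sz0, 4]
      : tensor<?xf32> into tensor<?x4xf32>
  return %e : tensor<?x4xf32>
}
```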
expandedShape = expandingReshapeOp.getMixedOutputShape();
reassociationIndices = expandingReshapeOp.getReassociationMaps();
src = expandingReshapeOp.getSrc();
} else {
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
MaheshRavishankar marked this conversation as resolved.
if (!collapsingReshapeOp)
return std::nullopt;

expandedShape = tensor::getMixedSizes(
rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
reassociationIndices = collapsingReshapeOp.getReassociationMaps();
src = collapsingReshapeOp.getSrc();
}

ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(
linalgOp, fusableOpOperand,
isExpanding ? expandingReshapeOp.getReassociationMaps()
: collapsingReshapeOp.getReassociationMaps(),
expandedType.getShape(), collapsedType.getShape(), rewriter)))
return std::nullopt;

// TODO: With the support of multiple dynamic dims expansion in
// tensor.expand_shape op, this case can be handled.
if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
return std::nullopt;

if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
reassociationIndices, expandedShape,
rewriter)))
return std::nullopt;

SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -950,15 +912,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
: collapsingReshapeOp.getSrc());
expandedOpOperands.push_back(src);
continue;
}
if (auto opOperandType =
dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
RankedTensorType expandedOperandType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
SmallVector<OpFoldResult> expandedOperandShape;
RankedTensorType expandedOperandType;
std::tie(expandedOperandShape, expandedOperandType) =
getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +935,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
loc, expandedOperandType, opOperand->get(), reassociation));
loc, expandedOperandType, opOperand->get(), reassociation,
expandedOperandShape));
continue;
}
}
@@ -983,8 +947,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
RankedTensorType expandedOutputType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
SmallVector<OpFoldResult> expandedOutputShape;
RankedTensorType expandedOutputType;
std::tie(expandedOutputShape, expandedOutputType) =
getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand.get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
@@ -997,7 +963,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
loc, expandedOutputType, opOperand.get(), reassociation,
expandedOutputShape));
} else {
outputs.push_back(opOperand.get());
}