
[mlir][linalg] Enable fusion by expansion of reduction and named ops #83473

Merged
108 changes: 58 additions & 50 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -467,10 +467,10 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
/// following fusion pattern.
/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
/// Consider
///
@@ -481,9 +481,9 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
/// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `genericOp` if its loop dimensionality
/// The reshape can be folded into the `linalgOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `genericOp` and the
/// The indexing_map of the fused tensor in the `linalgOp` and the
/// reassociation map helps compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
@@ -502,7 +502,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
/// d1 -> e2, e3, e4
/// d2 -> e5
///
/// substituting this, the generic op can be rewritten as
/// substituting this, the structured op can be rewritten as
///
/// %d = linalg.generic ins(%0, %1 : )
/// indexing_maps =
@@ -520,23 +520,28 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with its producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
OpOperand *fusableOpOperand) {
// Is fusable only if:
// - All the indexing maps for operands and results are projected
// permutations.
// - The fused tensor is not a scalar.
// - All the loops are parallel loops.
return genericOp.hasPureTensorSemantics() &&
llvm::all_of(genericOp.getIndexingMaps().getValue(),
// - All the loops for the reshaped operand are parallel loops.
SmallVector<utils::IteratorType> iteratorTypes =
linalgOp.getIteratorTypesArray();
AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
return linalgOp.hasPureTensorSemantics() &&
llvm::all_of(linalgOp.getIndexingMaps().getValue(),
[](Attribute attr) {
return cast<AffineMapAttr>(attr)
.getValue()
.isProjectedPermutation();
}) &&
genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
0 &&
llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
operandMap.getNumResults() > 0 &&
llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
return isParallelIterator(
iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
});
}

namespace {
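With the relaxed check above, only the loops referenced by the reshaped operand have to be parallel, so ops that also carry reduction loops become fusable. A hypothetical sketch of the kind of IR this newly admits (the names %a, %b, %arg1, %init, the shapes, and the multiply-accumulate body are all invented for illustration): the collapsed operand %b only indexes the two parallel dims d0 and d1, so the reduction loop d2 no longer blocks the fold, whereas the old check required every loop of the op to be parallel.

  %b = tensor.collapse_shape %arg1 [[0, 1], [2]]
      : tensor<2x4x8xf32> into tensor<8x8xf32>
  %r = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%a, %b : tensor<8x8x16xf32>, tensor<8x8xf32>)
      outs(%init : tensor<8x8xf32>) {
    ^bb0(%x: f32, %y: f32, %acc: f32):
      %m = arith.mulf %x, %y : f32
      %s = arith.addf %acc, %m : f32
      linalg.yield %s : f32
  } -> tensor<8x8xf32>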
@@ -628,10 +633,10 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
/// Note that this could be extended to handle the dynamic case, but the
/// implementation below uses `affine.apply` which seems to have issues when the
/// shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
if (!genericOp.hasIndexSemantics())
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
if (!linalgOp.hasIndexSemantics())
return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
@@ -640,7 +645,7 @@ static LogicalResult isGenericOpExpandable(GenericOp genericOp,
for (int64_t shape : expandedShape.drop_front()) {
if (ShapedType::isDynamic(shape)) {
return rewriter.notifyMatchFailure(
genericOp, "cannot expand due to index semantics and dynamic dims");
linalgOp, "cannot expand due to index semantics and dynamic dims");
}
}
}
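For illustration only, a hypothetical sketch of an op this guard rejects (the names %arg0, %init and all shapes are invented): the body uses linalg.index, and the non-outermost size of the dimension being expanded is dynamic, so the index arithmetic cannot be rewritten with affine.apply and the fold is refused.

  %c = tensor.collapse_shape %arg0 [[0, 1]]
      : tensor<4x?xf32> into tensor<?xf32>
  // Rejected: index semantics + dynamic inner expanded size (4, ?).
  %r = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%c : tensor<?xf32>) outs(%init : tensor<?xindex>) {
    ^bb0(%x: f32, %out: index):
      %i = linalg.index 0 : index
      linalg.yield %i : index
  } -> tensor<?xindex>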
@@ -749,10 +754,10 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`. Assumes
/// that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
OpOperand *fusableOpOperand,
PatternRewriter &rewriter) {
assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
"preconditions for fuse operation failed");
// Check if reshape is expanding or collapsing.
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
@@ -767,35 +772,35 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,

ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(
genericOp, fusableOpOperand,
linalgOp, fusableOpOperand,
isExpanding ? expandingReshapeOp.getReassociationMaps()
: collapsingReshapeOp.getReassociationMaps(),
expandedType.getShape(), collapsedType.getShape(), rewriter)))
return std::nullopt;

if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
return std::nullopt;

SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
}));

// Set insertion point to the generic op.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(genericOp);
rewriter.setInsertionPoint(linalgOp);

SmallVector<Value> expandedOpOperands;
expandedOpOperands.reserve(genericOp.getNumDpsInputs());
for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
: collapsingReshapeOp.getSrc());
continue;
}
if (auto opOperandType =
dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
RankedTensorType expandedOperandType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
@@ -804,25 +809,25 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
getReassociationForExpansion(indexingMap, expansionInfo);
if (failed(reshapeLikeShapesAreCompatible(
[&](const Twine &msg) {
return rewriter.notifyMatchFailure(genericOp, msg);
return rewriter.notifyMatchFailure(linalgOp, msg);
},
opOperandType.getShape(), expandedOperandType.getShape(),
reassociation,
/*isExpandingReshape=*/true)))
return std::nullopt;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
genericOp.getLoc(), expandedOperandType, opOperand->get(),
linalgOp.getLoc(), expandedOperandType, opOperand->get(),
reassociation));
continue;
}
}
expandedOpOperands.push_back(opOperand->get());
}

Location loc = genericOp.getLoc();
Location loc = linalgOp.getLoc();
SmallVector<Value> outputs;
for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) {
AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
RankedTensorType expandedOutputType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -831,14 +836,14 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
getReassociationForExpansion(indexingMap, expansionInfo);
if (failed(reshapeLikeShapesAreCompatible(
[&](const Twine &msg) {
return rewriter.notifyMatchFailure(genericOp, msg);
return rewriter.notifyMatchFailure(linalgOp, msg);
},
opOperandType.getShape(), expandedOutputType.getShape(),
reassociation,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
genericOp.getLoc(), expandedOutputType, opOperand.get(),
linalgOp.getLoc(), expandedOutputType, opOperand.get(),
reassociation));
} else {
outputs.push_back(opOperand.get());
@@ -848,14 +853,17 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
// Iterator types of the expanded op: every expanded dim inherits the iterator
// type of the original dim it was expanded from (only parallel dims expand).
SmallVector<utils::IteratorType> iteratorTypes(
expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
for (auto j : expansionInfo.getExpandedDims(i))
iteratorTypes[j] = type;

TypeRange resultTypes = ValueRange(outputs).getTypes();
auto fusedOp =
rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
/*inputs=*/expandedOpOperands, outputs,
expandedOpIndexingMaps, iteratorTypes);
Region &fusedRegion = fusedOp->getRegion(0);
Region &originalRegion = genericOp->getRegion(0);
Region &originalRegion = linalgOp->getRegion(0);
rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

// Update the index accesses after the expansion.
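Continuing the hypothetical sketch from above (all names and shapes invented): d0 splits into (2, 4) while d1 and the reduction dim d2 carry over unchanged, which is what the iterator-type copy a few lines up produces. The fused op would look roughly like:

  %a_e = tensor.expand_shape %a [[0, 1], [2], [3]]
      : tensor<8x8x16xf32> into tensor<2x4x8x16xf32>
  %init_e = tensor.expand_shape %init [[0, 1], [2]]
      : tensor<8x8xf32> into tensor<2x4x8xf32>
  %fused = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%a_e, %arg1 : tensor<2x4x8x16xf32>, tensor<2x4x8xf32>)
      outs(%init_e : tensor<2x4x8xf32>) {
    ^bb0(%x: f32, %y: f32, %acc: f32):
      %m = arith.mulf %x, %y : f32
      %s = arith.addf %acc, %m : f32
      linalg.yield %s : f32
  } -> tensor<2x4x8xf32>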
@@ -864,16 +872,16 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
for (OpResult opResult : genericOp->getOpResults()) {
for (OpResult opResult : linalgOp->getOpResults()) {
int64_t resultNumber = opResult.getResultNumber();
if (resultTypes[resultNumber] != opResult.getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(
genericOp.getMatchingIndexingMap(
genericOp.getDpsInitOperand(resultNumber)),
linalgOp.getMatchingIndexingMap(
linalgOp.getDpsInitOperand(resultNumber)),
expansionInfo);
resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
genericOp.getLoc(), opResult.getType(),
linalgOp.getLoc(), opResult.getType(),
fusedOp->getResult(resultNumber), reassociation));
} else {
resultVals.push_back(fusedOp->getResult(resultNumber));
@@ -885,37 +893,37 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOp> {
: public OpInterfaceRewritePattern<LinalgOp> {
public:
FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
tensor::CollapseShapeOp reshapeOp =
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
continue;
// Fold only if
// - The tensor reshape op is folding.
// - All constraints of fusing with reshape by expansion are met.
if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
(!controlFoldingReshapes(opOperand)))
continue;

std::optional<SmallVector<Value>> replacementValues =
fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(genericOp, *replacementValues);
rewriter.replaceOp(linalgOp, *replacementValues);
return success();
}
return failure();
@@ -945,7 +953,7 @@ struct FoldReshapeWithGenericOpByExpansion
"source not produced by an operation");
}

auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
if (!producer) {
return rewriter.notifyMatchFailure(reshapeOp,
"producer not a generic op");