Skip to content

[mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp #68526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1047,16 +1047,18 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
ArrayRef<ReassociationIndices> dimSequences);

/// Collapses dimensions of linalg.generic operation. A precondition to
/// calling this method is that for each list in `foldedIterationDim`, the
/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
/// to calling this method is that for each list in `foldedIterationDim`, the
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
/// When valid, the method also collapses the operands of the op. Returns
/// replacement values of the results of the original `genericOp` by inserting
/// replacement values of the results of the original `linalgOp` by inserting
/// reshapes to get back values of compatible types.
FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);
template <typename LinalgType>
FailureOr<SmallVector<Value>>
collapseOpIterationDims(LinalgType op,
ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);

struct LowerPackResult {
tensor::PadOp padOp;
Expand Down Expand Up @@ -1515,7 +1517,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
/// to return an array of `ReassociationIndices` representing dimensions that
/// should be merged.
using GetCollapsableDimensionsFn =
std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;

/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
/// tensor operands when needed and expand back the result tensors.
Expand Down
187 changes: 103 additions & 84 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap,
}

/// Get the new value to use for a given `OpOperand` in the collapsed operation.
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
static Value getCollapsedOpOperand(Location loc, LinalgOp op,
OpOperand *opOperand,
const CollapsingInfo &collapsingInfo,
OpBuilder &builder) {
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
SmallVector<ReassociationIndices> operandReassociation =
getOperandReassociation(indexingMap, collapsingInfo);

// If the number of entries in the reassocation for the operand is same as the
// number of results of the indexing map, then nothing to do for this operand.
// If the number of entries in the reassociation for the operand is same as
// the number of results of the indexing map, then nothing to do for this
// operand.
Value operand = opOperand->get();
if (operandReassociation.size() == indexingMap.getNumResults())
return operand;
Expand Down Expand Up @@ -1439,41 +1440,100 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
}
}

template <typename LinalgType>
Operation *createCollapsedOp(LinalgType op,
const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter) {
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static_asserts create problems in deployment. This is an optimization, not required for correctness. It is completely reasonable to have an operation not collapse and optimization to return a failure. The caller can then handle appropriately. Please change to fail gracefully.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I don't understand it.
Whoever is planning to call this function can know in compile time that the type it uses as template is/isn't supported.
If we can catch bug in compilation time it is better than catch it during runtime.
Can you elaborate about static assert can cause deployment issues downstream?

"unsupported linalg op type to create");
Location loc = op->getLoc();

// Get the input operands.
SmallVector<Value> inputOperands =
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
rewriter);
});

// Get the output operands and result types.
SmallVector<Type> resultTypes;
SmallVector<Value> outputOperands;
resultTypes.reserve(op.getNumDpsInits());
outputOperands.reserve(op.getNumDpsInits());
for (OpOperand &output : op.getDpsInitsMutable()) {
Value newOutput =
getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
// If the op has "buffer semantics", then the init operands are ranked
// memrefs and the op has no results.
if (!op.hasBufferSemantics())
resultTypes.push_back(newOutput.getType());
}

if (isa<linalg::CopyOp>(op)) {
return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
outputOperands[0]);
}

// Get the iterator types for the operand.
SmallVector<utils::IteratorType> iteratorTypes =
getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);

// Get the indexing maps.
auto indexingMaps =
llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
});

Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
loc, resultTypes, inputOperands, outputOperands, indexingMaps,
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
Block *origOpBlock = &op->getRegion(0).front();
Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
collapsedOpBlock->getArguments());

return collapsedOp;
}

/// Implementation of fusion with reshape operation by collapsing dimensions.
FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
template <typename LinalgType>
FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
"unsupported linalg op type to collapse");

// Bail on trivial no-op cases.
if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
return foldedDims.size() <= 1;
}))
return failure();

bool hasBufferSemantics = genericOp.hasBufferSemantics();
bool hasBufferSemantics = op.hasBufferSemantics();
if (hasBufferSemantics &&
!llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
!llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
if (!memRefToCollapse)
return true;

return memref::CollapseShapeOp::isGuaranteedCollapsible(
memRefToCollapse, foldedIterationDims);
}))
return rewriter.notifyMatchFailure(genericOp,
return rewriter.notifyMatchFailure(op,
"memref is not guaranteed collapsible");

CollapsingInfo collapsingInfo;
if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
foldedIterationDims))) {
if (failed(
collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
return rewriter.notifyMatchFailure(
genericOp, "illegal to collapse specified dimensions");
op, "illegal to collapse specified dimensions");
}

// Bail on non-canonical ranges.
SmallVector<Range> loopRanges =
cast<LinalgOp>(genericOp.getOperation())
.createLoopRanges(rewriter, genericOp.getLoc());
cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
Expand All @@ -1486,78 +1546,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
opFoldIsConstantValue(range.stride, 1);
})) {
return rewriter.notifyMatchFailure(
genericOp,
"expected all loop ranges to have zero start and unit stride");
op, "expected all loop ranges to have zero start and unit stride");
}

// Get the iterator types for the operand.
SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
genericOp.getIteratorTypesArray(), collapsingInfo);

// Get the indexing maps.
auto indexingMaps = llvm::to_vector(
llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
}));

Location loc = genericOp->getLoc();

// Get the input operands.
auto inputOperands = llvm::to_vector(llvm::map_range(
genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
rewriter);
}));

// Get the output operands and result types.
SmallVector<Type> resultTypes;
SmallVector<Value> outputOperands;
resultTypes.reserve(genericOp.getNumDpsInits());
outputOperands.reserve(genericOp.getNumDpsInits());
for (OpOperand &output : genericOp.getDpsInitsMutable()) {
Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
// If the op has "buffer semantics", then the init operands are ranked
// memrefs and the op has no results.
if (!hasBufferSemantics)
resultTypes.push_back(newOutput.getType());
}

// Create the generic op.
auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
loc, resultTypes, inputOperands, outputOperands, indexingMaps,
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
Block *origOpBlock = &genericOp->getRegion(0).front();
Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
collapsedOpBlock->getArguments());
LinalgType collapsedOp = cast<LinalgType>(
createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));

if (collapsedGenericOp.hasIndexSemantics()) {
Location loc = op->getLoc();
if (collapsedOp.hasIndexSemantics()) {
// Collect the loop range of the generic op.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(collapsedGenericOp);
rewriter.setInsertionPoint(collapsedOp);
SmallVector<Value> loopBound =
llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
llvm::map_to_vector(loopRanges, [&](Range range) {
return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
}));
generateCollapsedIndexingRegion(loc,
&collapsedGenericOp->getRegion(0).front(),
});
generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
collapsingInfo, loopBound, rewriter);
}

// Insert expanding reshape for the result to get back the original result
// type.
SmallVector<Value> results;
for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
Value collapsedOpResult =
collapsedGenericOp->getResult(originalResult.index());
for (const auto &originalResult : llvm::enumerate(op->getResults())) {
Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
auto originalResultType =
cast<ShapedType>(originalResult.value().getType());
auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
genericOp.getIndexingMapMatchingResult(originalResult.value());
op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
if (isa<MemRefType>(collapsedOpResult.getType())) {
Expand Down Expand Up @@ -1606,8 +1624,8 @@ class FoldWithProducerReshapeOpByCollapsing
}

std::optional<SmallVector<Value>> replacements =
collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
rewriter);
collapseOpIterationDims<linalg::GenericOp>(
genericOp, collapsableIterationDims, rewriter);
if (!replacements) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
Expand All @@ -1624,36 +1642,36 @@ class FoldWithProducerReshapeOpByCollapsing
};

/// Pattern to collapse dimensions.
class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
CollapseLinalgDimensions(MLIRContext *context,
GetCollapsableDimensionsFn collapseDimensions,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
: OpRewritePattern<LinalgType>(context, benefit),
controlCollapseDimension(std::move(collapseDimensions)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
LogicalResult matchAndRewrite(LinalgType op,
PatternRewriter &rewriter) const override {
SmallVector<ReassociationIndices> collapsableIterationDims =
controlCollapseDimension(genericOp);
controlCollapseDimension(op);
if (collapsableIterationDims.empty())
return failure();

// Check if the specified list of dimensions to collapse is a valid list.
if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
collapsableIterationDims)) {
return rewriter.notifyMatchFailure(
genericOp, "specified dimensions cannot be collapsed");
op, "specified dimensions cannot be collapsed");
}

std::optional<SmallVector<Value>> replacements =
collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
rewriter);
collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
rewriter);
if (!replacements) {
return rewriter.notifyMatchFailure(genericOp,
"failed to collapse dimensions");
return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
}
rewriter.replaceOp(genericOp, *replacements);
rewriter.replaceOp(op, *replacements);
return success();
}

Expand Down Expand Up @@ -1884,8 +1902,9 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
void mlir::linalg::populateCollapseDimensions(
RewritePatternSet &patterns,
const GetCollapsableDimensionsFn &controlCollapseDimensions) {
patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
controlCollapseDimensions);
patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
CollapseLinalgDimensions<linalg::CopyOp>>(
patterns.getContext(), controlCollapseDimensions);
}

//===---------------------------------------------------------------------===//
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Dialect/Linalg/collapse-dim.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,40 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
}
return %alloc : memref<2x6x24x48xi32>
}

// -----

// CHECK-LABEL: func.func @linalg_copy(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK: }

func.func @linalg_copy(
%arg0: tensor<1x2x3x4x5xf32, 1>, %arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> {
%0 = linalg.copy ins(%arg0: tensor<1x2x3x4x5xf32, 1>) outs(%arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3>
return %0 : tensor<1x2x3x4x5xf32, 3>
}

// -----

// CHECK-LABEL: func.func private @memref_linalg_copy(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) {
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
// CHECK: linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>)
// CHECK: return
// CHECK: }

func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: memref<1x24x32x8xf32, 1>) {
linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ struct TestLinalgElementwiseFusion
SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
collapseDimensions.end());
linalg::GetCollapsableDimensionsFn collapseFn =
[&dims](linalg::GenericOp op) {
[&dims](linalg::LinalgOp op) {
SmallVector<ReassociationIndices> reassociations;
reassociations.emplace_back(dims);
return reassociations;
Expand Down