[mlir][tensor] Rewrite tensor.pack as a constant #93954

Closed
wants to merge 1 commit into from
155 changes: 154 additions & 1 deletion mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -6,10 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Threading.h"

using namespace mlir;
using namespace mlir::tensor;
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
}
};

/// Rewrite tensor.pack with arith.constant if the pack is writing
/// to an empty tensor and the destination shape is static.
struct PackToConstant : OpRewritePattern<tensor::PackOp> {
Contributor

Isn't the expectation that these patterns come with a way to control when they get applied?

using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::PackOp packOp,
PatternRewriter &rewriter) const override {
auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
if (!constOp)
return failure();
// Must be a dense constant.
auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
if (!denseAttr)
return failure();

// Bail out if the pack is used as a writing operation i.e.,
// the destination is not a tensor.empty.
if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
return rewriter.notifyMatchFailure(packOp,
"expects empty tensor destination");
// Pack destination must have static shape.
if (!packOp.getDestType().hasStaticShape())
return rewriter.notifyMatchFailure(
packOp, "expects destination with static shape");
Comment on lines +66 to +74
Contributor

I think we can drop the comments because the failure messages already document it for us.


// Pack with padding is not supported currently.
// TODO: Insert padding values as a part of rewrite.
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
Contributor

nit: perhaps say that it is NIY (not implemented yet) in the failure message.


OpBuilder::InsertionGuard guard(rewriter);

// If it is a splat constant, rewrite the pack directly.
if (denseAttr.isSplat()) {
DenseElementsAttr packedDenseShape =
denseAttr.reshape(packOp.getDestType());
rewriter.setInsertionPoint(constOp);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);

return success();
}
Comment on lines +83 to +91
Contributor

This case is already covered in folders.

OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
std::optional<Attribute> paddingValue;
if (auto pad = adaptor.getPaddingValue())
paddingValue = pad;
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getDestType(), paddingValue))
return reshapedSource;
return {};
}


// Constant contains non-splat dense values.
// Move the data into a new packed buffer. Each value is placed into its new
// position as defined by the pack operation.
ArrayRef<char> srcRawData = denseAttr.getRawData();
SmallVector<char> destRawData(srcRawData.size());

int64_t numberOfElements = denseAttr.getNumElements();
SmallVector<int64_t> strides =
computeStrides(packOp.getDestType().getShape());

// Parallelize raw data movement to speed up large constant packing.
parallelFor(
packOp.getContext(), 0, numberOfElements,
[&](size_t destLinearizedIdx) {
// Step 1: De-linearize destination index.
// f(lin) = tmp[A][B][C]
SmallVector<int64_t> destIndices =
delinearize(destLinearizedIdx, strides);

// Step 2: Arrange the indexes based on the packing information.
// Compute inverse of outerDimsPerm to bring the loops into the
// canonical form tmp[A][B][a][b].
if (!packOp.getOuterDimsPerm().empty()) {
SmallVector<int64_t> inversePermutation =
invertPermutationVector(packOp.getOuterDimsPerm());
Comment on lines +116 to +117
Contributor

nit: move it before where it is used, i.e., line 120.

SmallVector<int64_t> tileLoops;
for (int64_t i = 0; i < packOp.getSourceType().getRank(); i++)
Contributor

nit: we can use packOp.getSourceRank().

tileLoops.push_back(destIndices[i]);
applyPermutationToVector(tileLoops, inversePermutation);

SmallVector<int64_t> pointLoops;
for (size_t i = packOp.getSourceType().getRank();
i < destIndices.size(); i++) {
pointLoops.push_back(destIndices[i]);
}
Comment on lines +123 to +127
Contributor

I think we can simplify it by using llvm::to_vector + llvm::seq<int64_t>. Also we can use getSourceRank().

Suggested change
SmallVector<int64_t> pointLoops;
for (size_t i = packOp.getSourceType().getRank();
i < destIndices.size(); i++) {
pointLoops.push_back(destIndices[i]);
}
SmallVector<int64_t> pointLoops = llvm::to_vector(llvm::seq<int64_t>(packOp.getSourceRank(), destIndices.size()));

[optional] using packOp.getDestRank() instead of destIndices.size() is slightly better to me. Because it directly ties to the pack op, so people don't need to look at what destIndices is.
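
One thing worth double-checking with the suggestion above: `llvm::seq<int64_t>(a, b)` enumerates the integers in the half-open range `[a, b)`, so turning it into a vector yields positions, whereas the original loop collects the values stored in `destIndices` at those positions. The standalone sketch below shows the difference using only the standard library. `std::iota` stands in for `llvm::seq`, and the helper names `positions` and `tailValues` are hypothetical, chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Integers in [begin, end): what a seq-style range enumerates.
// std::iota is used here as a stand-in for llvm::seq (assumption).
std::vector<int64_t> positions(int64_t begin, int64_t end) {
  std::vector<int64_t> out(static_cast<size_t>(end - begin));
  std::iota(out.begin(), out.end(), begin);
  return out;
}

// Values stored in `indices` from position `begin` to the end:
// what the original push_back loop collects into pointLoops.
std::vector<int64_t> tailValues(const std::vector<int64_t> &indices,
                                size_t begin) {
  assert(begin <= indices.size() && "begin must be within the vector");
  return std::vector<int64_t>(indices.begin() + begin, indices.end());
}
```

If the values are what is wanted, constructing the vector from an iterator range over `destIndices` (as in `tailValues`) keeps the one-line form without changing behavior.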


destIndices = tileLoops;
destIndices.append(pointLoops.begin(), pointLoops.end());
Comment on lines +129 to +130
Contributor

After reviewing the changes, can you check whether getPackInverseDestPerm + applyPermutationToVector works?

SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);

/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
/// vector `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a',
/// 'b']`.
template <typename T, unsigned N>
void applyPermutationToVector(SmallVector<T, N> &inVec,
ArrayRef<int64_t> permutation) {
inVec = applyPermutation(inVec, permutation);
}
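
For reference, the example in the doc comment above corresponds to gather semantics: element `i` of the result is taken from `inVec[permutation[i]]`. The sketch below is a standalone illustration of that behavior and of how an inverse permutation undoes it. It does not use the MLIR utilities; `applyPerm` and `invertPerm` are hypothetical names for this example only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Gather-style permutation: out[i] = in[perm[i]].
// With in = ['a', 'b', 'c'] and perm = [2, 0, 1] this gives ['c', 'a', 'b'],
// matching the example in the quoted doc comment.
template <typename T>
std::vector<T> applyPerm(const std::vector<T> &in,
                         const std::vector<int64_t> &perm) {
  assert(in.size() == perm.size() && "permutation size must match input");
  std::vector<T> out;
  out.reserve(in.size());
  for (int64_t p : perm)
    out.push_back(in[static_cast<size_t>(p)]);
  return out;
}

// Inverse permutation: inv[perm[i]] = i, so applying `perm` and then `inv`
// restores the original order.
std::vector<int64_t> invertPerm(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[static_cast<size_t>(perm[i])] = static_cast<int64_t>(i);
  return inv;
}
```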

}
assert(destIndices.size() ==
static_cast<size_t>(packOp.getDestType().getRank()));
Contributor

nit: use getDestRank() method.


// After interchanging the outermost tiled loop we end up in the
// canonical form tmp[A][B][a][b]. Squash the point loops with the
// tiled ones.
llvm::DenseSet<int64_t> tiledLoops(packOp.getInnerDimsPos().begin(),
packOp.getInnerDimsPos().end());
llvm::DenseMap<int64_t, int64_t> mappingTileToPointLoops;
// Map the position of the tiled loops with the point one.
// For example:
// [A][B] -> [A][B][a][b]
// entry: [A : 0] [a : 2]
// entry: [B : 1] [b : 3]
// [A][B] -> [A][B][b]
// entry: [B : 1] [b : 2]
for (auto [idx, tileLoop] : llvm::enumerate(packOp.getInnerDimsPos()))
mappingTileToPointLoops[tileLoop] = idx;
Comment on lines +141 to +149
Contributor

I think the comment is off? Do we have entries for inner dims? All the values in getInnerDimsPos are less than the rank of the source, so we won't have [a : 2], [b : 3] and [b : 2] entries. Do I misunderstand something?
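
To make the question concrete, here is a standalone sketch of the destination-to-source index computation for the case without `outer_dims_perm`. It is not the PR's code and uses no MLIR types: `PackSpec` and all function names are hypothetical, and element data is modeled as a plain `std::vector<int>`. The map goes from a tiled source dimension to its position `k` in `inner_dims_pos`, so its values are all smaller than the number of tiles, and the matching point-loop index is read at `srcRank + k`. The sketch also looks the tile size up by that same position `k`, which I believe matters when `inner_dims_pos` is not sorted.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical description of a pack without outer_dims_perm or padding.
struct PackSpec {
  std::vector<int64_t> srcShape;
  std::vector<int64_t> innerDimsPos; // tiled source dims, one per tile
  std::vector<int64_t> tileSizes;    // tileSizes[k] tiles innerDimsPos[k]
};

// Row-major strides, e.g. {2, 3, 4} -> {12, 4, 1}.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

std::vector<int64_t> delinearizeIndex(int64_t linear,
                                      const std::vector<int64_t> &strides) {
  std::vector<int64_t> idx(strides.size());
  for (size_t i = 0; i < strides.size(); ++i) {
    idx[i] = linear / strides[i];
    linear %= strides[i];
  }
  return idx;
}

int64_t linearizeIndex(const std::vector<int64_t> &idx,
                       const std::vector<int64_t> &strides) {
  assert(idx.size() == strides.size() && "rank mismatch");
  int64_t linear = 0;
  for (size_t i = 0; i < idx.size(); ++i)
    linear += idx[i] * strides[i];
  return linear;
}

// Packed shape: outer dims (source dims divided by their tile), then tiles.
std::vector<int64_t> packedShape(const PackSpec &spec) {
  std::vector<int64_t> shape = spec.srcShape;
  for (size_t k = 0; k < spec.innerDimsPos.size(); ++k)
    shape[spec.innerDimsPos[k]] /= spec.tileSizes[k];
  shape.insert(shape.end(), spec.tileSizes.begin(), spec.tileSizes.end());
  return shape;
}

// Map a destination index [outer..., inner...] back to a source index.
std::vector<int64_t> destToSrcIndices(const PackSpec &spec,
                                      const std::vector<int64_t> &destIdx) {
  const size_t srcRank = spec.srcShape.size();
  // Source dim -> position k of its tile in innerDimsPos (values < #tiles).
  std::map<int64_t, size_t> dimToTilePos;
  for (size_t k = 0; k < spec.innerDimsPos.size(); ++k)
    dimToTilePos[spec.innerDimsPos[k]] = k;

  std::vector<int64_t> srcIdx(srcRank);
  for (size_t d = 0; d < srcRank; ++d) {
    auto it = dimToTilePos.find(static_cast<int64_t>(d));
    if (it == dimToTilePos.end()) {
      srcIdx[d] = destIdx[d]; // dim is not tiled
    } else {
      const size_t k = it->second;
      // outer index * tile size + point-loop index stored at srcRank + k.
      srcIdx[d] = destIdx[d] * spec.tileSizes[k] + destIdx[srcRank + k];
    }
  }
  return srcIdx;
}

// Pack by visiting every destination element and copying from the source.
std::vector<int> packConstant(const PackSpec &spec,
                              const std::vector<int> &src) {
  const std::vector<int64_t> destStrides = rowMajorStrides(packedShape(spec));
  const std::vector<int64_t> srcStrides = rowMajorStrides(spec.srcShape);
  std::vector<int> dest(src.size());
  for (int64_t lin = 0; lin < static_cast<int64_t>(dest.size()); ++lin) {
    std::vector<int64_t> srcIdx =
        destToSrcIndices(spec, delinearizeIndex(lin, destStrides));
    dest[lin] = src[linearizeIndex(srcIdx, srcStrides)];
  }
  return dest;
}
```

For a 4x2 source holding 0..7 with dim 0 tiled by 2, the packed shape is 2x2x2 and each destination element `[A][B][a]` reads source element `[A * 2 + a][B]`.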


SmallVector<int64_t> srcIndices;
SmallVector<int64_t> tilesSizes = packOp.getStaticTiles();
int64_t numberOfTileLoops = packOp.getSourceType().getRank();
Contributor

nit: getSourceRank()

size_t tilePosIdx = 0;
for (int64_t i = 0; i < numberOfTileLoops; i++) {
if (!tiledLoops.count(i)) {
// Loop is not tiled.
srcIndices.push_back(destIndices[i]);
} else {
// Loop is tiled, account for the point loop distance.
srcIndices.push_back(
destIndices[i] * tilesSizes[tilePosIdx] +
destIndices[numberOfTileLoops + mappingTileToPointLoops[i]]);
tilePosIdx++;
}
}
assert(srcIndices.size() == static_cast<size_t>(numberOfTileLoops));

int64_t srcLinearizedIdx = linearize(
srcIndices, computeStrides(packOp.getSourceType().getShape()));
assert(srcLinearizedIdx < numberOfElements);

// Step 3: Do the packing.
// Copy the source element byte-wise to its packed destination
// position.
size_t elementByteSize =
denseAttr.getRawData().size() / denseAttr.getNumElements();
for (size_t i = 0; i < elementByteSize; i++) {
destRawData[destLinearizedIdx * elementByteSize + i] =
srcRawData[srcLinearizedIdx * elementByteSize + i];
}
});

// Fail gracefully if something went wrong.
bool detectSplat = false;
if (!DenseElementsAttr::isValidRawBuffer(packOp.getDestType(), destRawData,
detectSplat))
return rewriter.notifyMatchFailure(
packOp, "failed to create packed raw data buffer");

// Replace the pack with a new constant.
auto packedDenseShape =
DenseElementsAttr::getFromRawBuffer(packOp.getDestType(), destRawData);
rewriter.setInsertionPoint(constOp);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);

return success();
}
};

} // namespace

void mlir::tensor::populateRewriteAsConstantPatterns(
RewritePatternSet &patterns) {
- patterns.add<GenerateToConstant>(patterns.getContext());
+ patterns.add<GenerateToConstant, PackToConstant>(patterns.getContext());
}