[mlir][MemRef] Use specialized index ops to fold expand/collapse_shape #138930

Merged
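
This change switches the expand_shape/collapse_shape folding patterns in `FoldMemRefAliasOps` from hand-built `affine.apply` maps to the dedicated `affine.linearize_index` / `affine.delinearize_index` ops. A minimal sketch of the expand_shape case (illustrative IR, not taken from the PR's tests):

```mlir
// Before folding: a load through an expand_shape.
%0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 42]
    : memref<12x42xf32> into memref<2x6x42xf32>
%1 = memref.load %0[%i1, %i2, %i3] : memref<2x6x42xf32>

// After folding: the indices of each reassociation group are linearized
// onto the source; `disjoint` is set because memref.load guarantees
// in-bounds indices.
%idx = affine.linearize_index disjoint [%i1, %i2] by (2, 6) : index
%1 = memref.load %arg0[%idx, %i3] : memref<12x42xf32>
```
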
14 changes: 10 additions & 4 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
according to specified offsets, sizes, and strides.

```mlir
%result1 = memref.reinterpret_cast %arg0 to
offset: [9],
sizes: [4, 4],
strides: [16, 2]
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
memref<4x4xf32, strided<[16, 2], offset: 9>>

%result2 = memref.reinterpret_cast %result1 to
offset: [0],
sizes: [2, 2],
strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
OpBuilder &b, Location loc, MemRefType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape);

// Return a vector with all the static and dynamic values in the output shape.
SmallVector<OpFoldResult> getMixedOutputShape() {
OpBuilder builder(getContext());
return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
}
}];

let hasVerifier = 1;
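
For reference, a sketch of what the new `getMixedOutputShape()` helper returns on an illustrative op (not part of this diff): static extents come back as `IntegerAttr`-backed `OpFoldResult`s taken from `static_output_shape`, and dynamic extents as the matching SSA values from the `output_shape` operand.

```mlir
// getMixedOutputShape() on this op yields [%sz0, 6, 42]: one dynamic value
// followed by two static extents.
%e = memref.expand_shape %src [[0, 1], [2]] output_shape [%sz0, 6, 42]
    : memref<?x42xf32> into memref<?x6x42xf32>
```
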
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
let summary = "store operation";
let description = [{
The `store` op stores an element into a memref at the specified indices.

The number of indices must match the rank of the memref. The indices must
be in-bounds: `0 <= idx < dim_size`

@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
Unlike the `reinterpret_cast`, the values are relative to the strided
memref of the input (`%result1` in this case) and not its
underlying memory.

Example 2:

```mlir
169 changes: 49 additions & 120 deletions mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
///
/// %2 = load %0[6 * i1 + i2, %i3] :
/// memref<12x42xf32>
static LogicalResult
resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
// Record the rewriter context for constructing ops later.
MLIRContext *ctx = rewriter.getContext();

// Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
// This is done for the purpose of inferring the output shape via
// `inferExpandShapeOutputShape` which will in turn be used for suffix product
// calculation later.
SmallVector<OpFoldResult> srcShape;
MemRefType srcType = expandShapeOp.getSrcType();

for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
if (srcType.isDynamicDim(i)) {
srcShape.push_back(
rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
.getResult());
} else {
srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
}
}

auto outputShape = inferExpandShapeOutputShape(
rewriter, loc, expandShapeOp.getResultType(),
expandShapeOp.getReassociationIndices(), srcShape);
if (!outputShape.has_value())
return failure();
static LogicalResult resolveSourceIndicesExpandShape(
Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();

// Traverse all reassociation groups to determine the appropriate indices
// corresponding to each one of them post op folding.
for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
assert(!groups.empty() && "association indices groups cannot be empty");
// Flag to indicate the presence of dynamic dimensions in current
// reassociation group.
int64_t groupSize = groups.size();

// Group output dimensions utilized in this reassociation group for suffix
// product calculation.
SmallVector<OpFoldResult> sizesVal(groupSize);
for (int64_t i = 0; i < groupSize; ++i) {
sizesVal[i] = (*outputShape)[groups[i]];
for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
assert(!group.empty() && "association indices groups cannot be empty");
int64_t groupSize = group.size();
if (groupSize == 1) {
sourceIndices.push_back(indices[group[0]]);
continue;
}

// Calculate suffix product of relevant output dimension sizes.
SmallVector<OpFoldResult> suffixProduct =
memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);

// Create affine expression variables for dimensions and symbols in the
// newly constructed affine map.
SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
bindDimsList<AffineExpr>(ctx, dims);
bindSymbolsList<AffineExpr>(ctx, symbols);

// Linearize bound dimensions and symbols to construct the resultant
// affine expression for this index.
AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);

// Record the load index corresponding to each dimension in the
// reassociation group. These are later supplied as operands to the affine
// map used for calculating the relevant index post op folding.
SmallVector<OpFoldResult> dynamicIndices(groupSize);
for (int64_t i = 0; i < groupSize; i++)
dynamicIndices[i] = indices[groups[i]];

// Supply suffix product results followed by load op indices as operands
// to the map.
SmallVector<OpFoldResult> mapOperands;
llvm::append_range(mapOperands, suffixProduct);
llvm::append_range(mapOperands, dynamicIndices);

// Creating maximally folded and composed affine.apply composes better
// with other transformations without interleaving canonicalization
// passes.
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc,
AffineMap::get(/*numDims=*/groupSize,
/*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
mapOperands);

// Push index value in the op post folding corresponding to this
// reassociation group.
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
SmallVector<OpFoldResult> groupBasis =
llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
SmallVector<Value> groupIndices =
llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
sourceIndices.push_back(collapsedIndex);
}
return success();
}
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
int64_t cnt = 0;
SmallVector<OpFoldResult> dynamicIndices;
for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
assert(!groups.empty() && "association indices groups cannot be empty");
dynamicIndices.push_back(indices[cnt++]);
int64_t groupSize = groups.size();

// Calculate suffix product for all collapse op source dimension sizes
// except the most major one of each group.
// We allow the most major source dimension to be dynamic but enforce all
// others to be known statically.
SmallVector<int64_t> sizes(groupSize, 1);
for (int64_t i = 1; i < groupSize; ++i) {
sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
if (sizes[i] == ShapedType::kDynamic)
return failure();
}
SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);

// Derive the index values along all dimensions of the source corresponding
// to the index w.r.t. the collapsed shape op output.
auto d0 = rewriter.getAffineDimExpr(0);
SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);

// Construct the AffineApplyOp for each delinearizingExpr.
for (int64_t i = 0; i < groupSize; i++) {
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc,
AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
delinearizingExprs[i]),
dynamicIndices);
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
MemRefType sourceType = collapseShapeOp.getSrcType();
// Note: collapse_shape requires a strided memref, so we can always extract
// its strided metadata here.
auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
loc, collapseShapeOp.getSrc());
SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
for (auto [index, group] :
llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
assert(!group.empty() && "association indices groups cannot be empty");
int64_t groupSize = group.size();

if (groupSize == 1) {
sourceIndices.push_back(index);
continue;
}
dynamicIndices.clear();

SmallVector<OpFoldResult> basis =
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, index, basis, /*hasOuterBound=*/true);
llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc, zeroAffineMap, dynamicIndices);
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
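
For the collapse_shape direction, the pattern now splits the collapsed index back into per-dimension source indices with `affine.delinearize_index`, using the source sizes (taken from `extract_strided_metadata` when dynamic) as the basis. A minimal sketch with static sizes, where the metadata op folds away (illustrative, not from the PR's tests):

```mlir
// Before folding: a load through a collapse_shape.
%0 = memref.collapse_shape %arg0 [[0, 1], [2]]
    : memref<2x6x42xf32> into memref<12x42xf32>
%1 = memref.load %0[%i1, %i2] : memref<12x42xf32>

// After folding: the collapsed index is delinearized over the source sizes.
%idxs:2 = affine.delinearize_index %i1 into (2, 6) : index, index
%1 = memref.load %arg0[%idxs#0, %idxs#1, %i2] : memref<2x6x42xf32>
```
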
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
// memref.load and affine.load guarantee that indices start in-bounds,
// while the vector operations don't. This affects whether our
// linearization is `disjoint`.
if (failed(resolveSourceIndicesExpandShape(
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
// memref.store and affine.store guarantee that indices start in-bounds,
// while the vector operations don't. This affects whether our
// linearization is `disjoint`.
if (failed(resolveSourceIndicesExpandShape(
storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case([&](affine::AffineStoreOp op) {
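
As the comments above note, only the `memref`/`affine` loads and stores guarantee in-bounds indices, so only those folds mark the linearization as `disjoint`. A sketch of the two forms the patterns can produce (illustrative values):

```mlir
// Folding memref.load / memref.store / affine.load / affine.store:
// indices are known in-bounds, so the linearization is disjoint.
%lin0 = affine.linearize_index disjoint [%i, %j] by (2, 6) : index

// Folding the vector operations: no in-bounds guarantee, no disjoint flag.
%lin1 = affine.linearize_index [%i, %j] by (2, 6) : index
```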