Fold memref.dim into memref.expand_shape #88423
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Changes

The lack of this folding pattern causes TypeConverter errors downstream (IREE), as `memref.dim` on `memref.expand_shape` causes non-1D memrefs to survive after we expect them to have been flattened.

This code is mostly copied from the corresponding TensorOps.cpp code, performing the corresponding folding of `tensor.dim`. The difference is that that code used an `AffineApplyOp`, which we can't do here because it would create a dependency of MemRefDialect on AffineDialect; that would be circular, since AffineDialect depends on MemRefDialect.

For the same reason, this PR only folds into `expand_shape` and not `collapse_shape`. Sorry about the asymmetry; the folding code for `collapse_shape` made more involved use of AffineDialect, so it would have been more work to reimplement without it, and for my own immediate purposes `expand_shape` is enough.

Full diff: https://github.com/llvm/llvm-project/pull/88423.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..edc055c3180f07 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1125,11 +1125,67 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
};
+int64_t getCorrespondingSourceDim(ExpandShapeOp expandShapeOp,
+ int64_t resultDim) {
+ assert(resultDim >= 0 &&
+ resultDim < expandShapeOp.getResultType().getRank() &&
+ "invalid resultDim");
+ for (const auto &it :
+ llvm::enumerate(expandShapeOp.getReassociationIndices()))
+ if (llvm::is_contained(it.value(), resultDim))
+ return it.index();
+ assert(false && "could not find reassociation group");
+ return 0;
+}
+
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dimOp,
+ PatternRewriter &rewriter) const override {
+ auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+ if (!expandShapeOp)
+ return failure();
+
+ // Only constant dimension values are supported.
+ std::optional<int64_t> dim = dimOp.getConstantIndex();
+ if (!dim.has_value())
+ return failure();
+
+ // Skip static dims. These are folded to constant ops.
+ MemRefType resultType = expandShapeOp.getResultType();
+ if (!resultType.isDynamicDim(*dim))
+ return failure();
+
+ // Find reassociation group that contains this result dimension.
+ int64_t srcDim = getCorrespondingSourceDim(expandShapeOp, *dim);
+
+ // `dim` is the only dynamic dimension in `group`. (Otherwise, the
+ // ExpandShapeOp would be ambiguous.)
+ int64_t product = 1;
+ ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
+ for (int64_t d : grp) {
+ if (d != dim) {
+ assert(!resultType.isDynamicDim(d) && "expected static dim");
+ product *= resultType.getDimSize(d);
+ }
+ }
+
+ // result dim size = src dim size / (product(other dims in reassoc group))
+ Value srcDimSz =
+ rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
+ rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(
+ dimOp, srcDimSz,
+ rewriter.create<arith::ConstantIndexOp>(dimOp.getLoc(), product));
+ return success();
+ }
+};
+
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfMemRefReshape>(context);
+ results.add<DimOfMemRefReshape, FoldDimOfExpandShape>(context);
}
// ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 506ed1f1c10b10..584f4d0e7067aa 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -313,6 +313,22 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
// -----
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 0
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+ -> index {
+ %c1 = arith.constant 1 : index
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
+ %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+ return %1 : index
+}
+
+// -----
+
// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
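For reference, here is a hedged sketch of the intermediate IR the new pattern would produce for the `@dim_of_memref_expand_shape` test above (value names are illustrative, not taken from the patch). The queried result dim 1 belongs to the reassociation group `[0, 1]`, whose only other size is the static 1, so the pattern emits a divide-by-one that subsequently folds away; that is why the CHECK lines expect only a `memref.dim` on the source.

```mlir
// Intermediate state right after the pattern fires, before constant folding:
%c0      = arith.constant 0 : index   // source dim for group [0, 1]
%c1_prod = arith.constant 1 : index   // product of the other (static) sizes in the group
%src_dim = memref.dim %arg0, %c0 : memref<?x8xi32>
%res     = arith.floordivsi %src_dim, %c1_prod : index
```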
Actually I am really surprised that the … If you want a separate pass you can just run …
This shouldn't be a canonicalization and should be implemented using the interface. See the other comment for details.
Can you expand why?
I can see how that makes sense. It means that for each such op, we just teach the compiler about computing its result shapes (by implementing …). I should have the patch for …
Superseded by #89111.
The lack of this folding pattern causes TypeConverter errors downstream (IREE), as `memref.dim` on `memref.expand_shape` causes non-1D memrefs to survive after we expect them to have been flattened.

This code is mostly copied from the corresponding TensorOps.cpp code, performing the corresponding folding of `tensor.dim`. The difference is that that code used an `AffineApplyOp`, which we can't do here because it would create a dependency of MemRefDialect on AffineDialect; that would be circular, since AffineDialect depends on MemRefDialect.

For the same reason, this PR only folds into `expand_shape` and not `collapse_shape`. Sorry about the asymmetry; the folding code for `collapse_shape` made more involved use of AffineDialect, so it would have been more work to reimplement without it, and for my own immediate purposes `expand_shape` is enough.
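A minimal sketch of the rewrite in isolation, with illustrative shapes and value names (not taken from the patch): a `memref.dim` of a dynamic dimension of an `expand_shape` result becomes a `memref.dim` on the source divided by the product of the other, static sizes in the same reassociation group, using `arith.floordivsi` rather than `affine.apply` to avoid the dialect dependency described above.

```mlir
// Before the fold: query the dynamic dim of the expanded result.
%c1 = arith.constant 1 : index
%e  = memref.expand_shape %src [[0, 1]] : memref<?xf32> into memref<4x?xf32>
%d  = memref.dim %e, %c1 : memref<4x?xf32>

// After the fold: compute it from the source, dividing out the static size 4
// from the same reassociation group; no affine.apply is created.
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%s  = memref.dim %src, %c0 : memref<?xf32>
%d2 = arith.floordivsi %s, %c4 : index
```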