Fold memref.dim into memref.expand_shape #88423

Closed
wants to merge 1 commit into from

Conversation

bjacob
Contributor

@bjacob bjacob commented Apr 11, 2024

The lack of this folding pattern causes TypeConverter errors downstream (in IREE), as memref.dim ops on memref.expand_shape results cause non-1D memrefs to survive after we expect them to have been flattened.

This code is mostly copied from the corresponding code in TensorOps.cpp, which performs the analogous folding of tensor.dim. The difference is that that code uses an AffineApplyOp, which we can't do here: it would create a dependency of MemRefDialect on AffineDialect, and that dependency would be circular, since AffineDialect already depends on MemRefDialect.

For the same reason, this PR only folds into expand_shape and not collapse_shape. Sorry about the asymmetry: the folding code for collapse_shape makes more involved use of AffineDialect, so it would have been more work to reimplement without it, and for my immediate purposes expand_shape is enough.
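To illustrate the rewrite, here is a minimal before/after sketch (hypothetical IR, not taken from the patch; the reassociation and types are made up for clarity):

```mlir
// Before: the dim of the dynamic result dimension queries the expand_shape.
func.func @dim_of_expand(%src: memref<?xi32>) -> index {
  %c0 = arith.constant 0 : index
  %e = memref.expand_shape %src [[0, 1]] : memref<?xi32> into memref<?x4xi32>
  %d = memref.dim %e, %c0 : memref<?x4xi32>
  return %d : index
}

// After the pattern: the dynamic extent is the source extent divided by the
// product of the other (static) extents in its reassociation group, here 4.
func.func @dim_of_expand(%src: memref<?xi32>) -> index {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %d0 = memref.dim %src, %c0 : memref<?xi32>
  %d = arith.floordivsi %d0, %c4 : index
  return %d : index
}
```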

@bjacob bjacob marked this pull request as ready for review April 11, 2024 19:10
@llvmbot
Member

llvmbot commented Apr 11, 2024

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Full diff: https://github.com/llvm/llvm-project/pull/88423.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+57-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+16)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..edc055c3180f07 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1125,11 +1125,67 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
   }
 };
 
+int64_t getCorrespondingSourceDim(ExpandShapeOp expandShapeOp,
+                                  int64_t resultDim) {
+  assert(resultDim >= 0 &&
+         resultDim < expandShapeOp.getResultType().getRank() &&
+         "invalid resultDim");
+  for (const auto &it :
+       llvm::enumerate(expandShapeOp.getReassociationIndices()))
+    if (llvm::is_contained(it.value(), resultDim))
+      return it.index();
+  assert(false && "could not find reassociation group");
+  return 0;
+}
+
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    std::optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    MemRefType resultType = expandShapeOp.getResultType();
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Find reassociation group that contains this result dimension.
+    int64_t srcDim = getCorrespondingSourceDim(expandShapeOp, *dim);
+
+    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
+    // ExpandShapeOp would be ambiguous.)
+    int64_t product = 1;
+    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
+    for (int64_t d : grp) {
+      if (d != dim) {
+        assert(!resultType.isDynamicDim(d) && "expected static dim");
+        product *= resultType.getDimSize(d);
+      }
+    }
+
+    // result dim size = src dim size / (product(other dims in reassoc group))
+    Value srcDimSz =
+        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
+    rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(
+        dimOp, srcDimSz,
+        rewriter.create<arith::ConstantIndexOp>(dimOp.getLoc(), product));
+    return success();
+  }
+};
+
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
+  results.add<DimOfMemRefReshape, FoldDimOfExpandShape>(context);
 }
 
 // ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 506ed1f1c10b10..584f4d0e7067aa 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -313,6 +313,22 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
 
 // -----
 
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 0
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+    -> index {
+  %c1 = arith.constant 1 : index
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
+  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+  return %1 : index
+}
+
+// -----
+
 // Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
 // CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
 //  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: memref<*xf32>,

@MaheshRavishankar
Contributor

Actually, I am really surprised that the memref.expand/collapse_shape and the tensor.expand/collapse_shape ops do not implement the ReifyRankedShapedTypeOpInterface. They really should, and dim folding shouldn't be part of canonicalization. If you make the operations implement the interface, you can then just add memref::populateResolveRankedShapedTypeResultDimsPatterns wherever you need to fold the dims away.

If you want a separate pass, you can just run the resolve-ranked-shaped-type-result-dims pass, which applies the same patterns.
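For illustration, a sketch of the pass-based flow, assuming memref.expand_shape implements the interface (hypothetical test; the pass name is taken from the comment above, reusing the example from the diff):

```mlir
// RUN: mlir-opt --resolve-ranked-shaped-type-result-dims %s
// With the interface implemented, the memref.dim below resolves to a dim of
// %src, and the expand_shape no longer has index-computation users.
func.func @resolve_dim(%src: memref<?x8xi32>) -> index {
  %c1 = arith.constant 1 : index
  %0 = memref.expand_shape %src [[0, 1], [2, 3]] : memref<?x8xi32> into memref<1x?x2x4xi32>
  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
  return %1 : index
}
```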

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

This shouldn't be a canonicalization; it should be implemented using the interface. See the other comment for details.

@joker-eph
Collaborator

This shouldn't be a canonicalization

Can you expand on why?

@bjacob
Contributor Author

bjacob commented Apr 12, 2024

I can see how that makes sense. It means that for each such op, we just teach the compiler how to compute its result shapes (by implementing reifyResultShapes), and it figures out how to do things such as folding dim ops, without the need for handwritten pattern-rewrite code.
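For concreteness, here is a sketch of what the reified dynamic extent looks like in IR for the test case in this PR (hedged: the exact ops emitted by an eventual reifyResultShapes implementation may differ):

```mlir
// For memref<?x8xi32> expanded into memref<1x?x2x4xi32> with reassociation
// [[0, 1], [2, 3]]: result dims 0, 2, 3 are the constants 1, 2, 4, and the
// dynamic result dim 1 derives from source dim 0, divided by the product of
// the other extents in its group [0, 1] (here, 1).
func.func @reified_dim1(%src: memref<?x8xi32>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = memref.dim %src, %c0 : memref<?x8xi32>
  %dyn = arith.floordivsi %d0, %c1 : index
  return %dyn : index
}
```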

I should have the patch for memref.expand_shape today. But it seems this treatment will be needed in more places, and I'm not sure I can commit to taking care of them all.

@bjacob
Contributor Author

bjacob commented Apr 17, 2024

Superseded by #89111.

@bjacob bjacob closed this Apr 17, 2024