Skip to content

Let memref.{expand,collapse}_shape implement ReifyRankedShapedTypeOpInterface #89111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

bjacob
Copy link
Contributor

@bjacob bjacob commented Apr 17, 2024

In #88423 I came across a need for folding memref.dim into memref.expand_shape and it was (IMO rightly) suggested that the proper way to fix that was to implement ReifyRankedShapedTypeOpInterface. This PR does that for both memref.{expand,collapse}_shape to be consistent, as it would be surprising if these two ops weren't closely mirroring one another.

It would be good to carry on the ReifyRankedShapedTypeOpInterface migration, in particular completing it for memref and tensor ops, particularly to tensor.{expand,collapse}_shape to be consistent with this (and then one could drop some existing custom Fold patterns). However, I have heard of an ongoing project to generalize expand_shape to relax the requirement that at most one dimension in each reassociation group be dynamic. It would be wise to allow for that project to complete first, as the ReifyRankedShapedTypeOpInterface implementation will otherwise entrench the current semantics.

FYI @qedawkins @MaheshRavishankar @Shukla-Gaurav

@bjacob bjacob marked this pull request as ready for review April 17, 2024 17:54
@llvmbot
Copy link
Member

llvmbot commented Apr 17, 2024

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Changes

In #88423 I came across a need for folding memref.dim into memref.expand_shape and it was (IMO rightly) suggested that the proper way to fix that was to implement ReifyRankedShapedTypeOpInterface. This PR does that for both memref.{expand,collapse}_shape to be consistent, as it would be surprising if these two ops weren't closely mirroring one another.

It would be good to carry on the ReifyRankedShapedTypeOpInterface migration, in particular completing it for memref and tensor ops, particularly to tensor.{expand,collapse}_shape to be consistent with this (and then one could drop some existing custom Fold patterns). However, I have heard of an ongoing project to generalize expand_shape to relax the requirement that at most one dimension in each reassociation group be dynamic. It would be wise to allow for that project to complete first, as the ReifyRankedShapedTypeOpInterface implementation will otherwise entrench the current semantics.

FYI @qedawkins @MaheshRavishankar @Shukla-Gaurav


Full diff: https://github.com/llvm/llvm-project/pull/89111.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+5-2)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+90)
  • (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+32)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..8f6bff5809ca2b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1546,8 +1546,11 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
 //===----------------------------------------------------------------------===//
 
 class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
-    MemRef_Op<mnemonic, !listconcat(traits,
-      [Pure, ViewLikeOpInterface])>,
+    MemRef_Op<mnemonic, !listconcat(traits, [
+      Pure,
+      ViewLikeOpInterface,
+      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>
+    ])>,
     Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..e2ce7d93d227cc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -2079,6 +2080,95 @@ void ExpandShapeOp::getAsmResultNames(
   setNameFn(getResult(), "expand_shape");
 }
 
+LogicalResult ExpandShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  SmallVector<OpFoldResult> resultDims;
+  ArrayRef<int64_t> expandedShape = this->getResultType().getShape();
+  for (size_t expanded_dim = 0; expanded_dim < expandedShape.size();
+       ++expanded_dim) {
+    if (ShapedType::isDynamic(expandedShape[expanded_dim])) {
+      // Dynamic dimension case. Map expanded_dim to the corresponded
+      // collapsed dim. All other expanded dimensions corresponding to
+      // that collapsed dim must be static-size. Compute their product
+      // to divide the result size by.
+      auto reassoc = this->getReassociationIndices();
+      for (size_t collapsed_dim = 0; collapsed_dim < reassoc.size();
+           ++collapsed_dim) {
+        ReassociationIndices associated_dims = reassoc[collapsed_dim];
+        bool found_expanded_dim = false;
+        int64_t other_associated_dims_product_size = 1;
+        for (size_t associated_dim : associated_dims) {
+          if (associated_dim == expanded_dim) {
+            found_expanded_dim = true;
+          } else {
+            assert(!ShapedType::isDynamic(expandedShape[associated_dim]) &&
+                   "At most one dimension of a reassociation group may be "
+                   "dynamic in the result type.");
+            other_associated_dims_product_size *= expandedShape[associated_dim];
+          }
+        }
+        if (!found_expanded_dim) {
+          continue;
+        }
+        Value srcDimSize =
+            builder.create<memref::DimOp>(getLoc(), getSrc(), collapsed_dim);
+        Value resultDimSize = builder.create<arith::DivSIOp>(
+            getLoc(), srcDimSize,
+            builder.create<arith::ConstantIndexOp>(
+                getLoc(), other_associated_dims_product_size));
+        resultDims.push_back(resultDimSize);
+      }
+    } else {
+      resultDims.push_back(getAsIndexOpFoldResult(builder.getContext(),
+                                                  expandedShape[expanded_dim]));
+    }
+  }
+  reifiedReturnShapes = {resultDims};
+  return success();
+}
+
+LogicalResult CollapseShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  SmallVector<OpFoldResult> resultDims;
+  ArrayRef<int64_t> collapsedShape = this->getResultType().getShape();
+  ArrayRef<int64_t> expandedShape = this->getSrcType().getShape();
+  for (size_t collapsed_dim = 0; collapsed_dim < collapsedShape.size();
+       ++collapsed_dim) {
+    if (ShapedType::isDynamic(collapsedShape[collapsed_dim])) {
+      // Dynamic dimension case. All other expanded dimensions corresponding
+      // to that collapsed_dim must be static-size. Compute their product
+      // to multiply the result size by.
+      auto reassoc = this->getReassociationIndices();
+      ReassociationIndices associated_dims = reassoc[collapsed_dim];
+      std::optional<size_t> expanded_dim;
+      int64_t other_associated_dims_product_size = 1;
+      for (size_t associated_dim : associated_dims) {
+        if (ShapedType::isDynamic(expandedShape[associated_dim])) {
+          assert(!expanded_dim && "At most one dimension of a reassociation "
+                                  "group may be dynamic in the result type.");
+          expanded_dim = associated_dim;
+        } else {
+          other_associated_dims_product_size *= expandedShape[associated_dim];
+        }
+      }
+      assert(expanded_dim && "No dynamic dimension in the reassociation group "
+                             "to match the dynamic collapsed dimension.");
+      Value srcDimSize =
+          builder.create<memref::DimOp>(getLoc(), getSrc(), *expanded_dim);
+      Value resultDimSize = builder.create<arith::MulIOp>(
+          getLoc(), srcDimSize,
+          builder.create<arith::ConstantIndexOp>(
+              getLoc(), other_associated_dims_product_size));
+      resultDims.push_back(resultDimSize);
+    } else {
+      resultDims.push_back(getAsIndexOpFoldResult(
+          builder.getContext(), collapsedShape[collapsed_dim]));
+    }
+  }
+  reifiedReturnShapes = {resultDims};
+  return success();
+}
+
 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
 /// result and operand. Layout maps are verified separately.
 ///
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 18e9a9d02e1081..fb0f9106e61bbf 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -25,3 +25,35 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
   %0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
   return %0 : index
 }
+
+// -----
+
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 0
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+    -> index {
+  %c1 = arith.constant 1 : index
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
+  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.collapse_shape)
+// CHECK-LABEL: func @dim_of_memref_collapse_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<1x?x2x4xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 1
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<1x?x2x4xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_collapse_shape(%arg0: memref<1x?x2x4xi32>)
+    -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2, 3]]: memref<1x?x2x4xi32> into memref<?x8xi32>
+  %1 = memref.dim %0, %c0 : memref<?x8xi32>
+  return %1 : index
+}

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a direct test for type inference?

OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
SmallVector<OpFoldResult> resultDims;
ArrayRef<int64_t> expandedShape = this->getResultType().getShape();
for (size_t expanded_dim = 0; expanded_dim < expandedShape.size();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +2121 to +2124
} else {
resultDims.push_back(getAsIndexOpFoldResult(builder.getContext(),
expandedShape[expanded_dim]));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SmallVector<OpFoldResult> resultDims;
ArrayRef<int64_t> expandedShape = this->getResultType().getShape();
for (size_t expanded_dim = 0; expanded_dim < expandedShape.size();
++expanded_dim) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLIR uses camelCase for variables, here and below.

// collapsed dim. All other expanded dimensions corresponding to
// that collapsed dim must be static-size. Compute their product
// to divide the result size by.
auto reassoc = this->getReassociationIndices();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please expand auto unless the type is obvious from line-level context.

}
Value srcDimSize =
builder.create<memref::DimOp>(getLoc(), getSrc(), collapsed_dim);
Value resultDimSize = builder.create<arith::DivSIOp>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather use an unsigned division here, it lowers to simpler code on most targets, and sizes are unsigned.

@bjacob
Copy link
Contributor Author

bjacob commented Apr 18, 2024

Thanks for the drive-by review. Actually I am rebasing on top of #69267 , which will make this PR trivial as now the output shape is encoded on the op already. The code simplifies to

LogicalResult ExpandShapeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
  return success();
}

@bjacob bjacob closed this Apr 18, 2024
bjacob added a commit that referenced this pull request May 3, 2024
#90975)

This is a new take on #89111. Now that #90040 is merged, this has become
trivial to implement. The added test shows the kind of benefit that we
get from this: now dim-of-expand-shape naturally folds without us
needing to implement an ad-hoc folding rewrite.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants