Let memref.{expand,collapse}_shape implement ReifyRankedShapedTypeOpInterface #89111
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Full diff: https://github.com/llvm/llvm-project/pull/89111.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..8f6bff5809ca2b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1546,8 +1546,11 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
//===----------------------------------------------------------------------===//
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
- MemRef_Op<mnemonic, !listconcat(traits,
- [Pure, ViewLikeOpInterface])>,
+ MemRef_Op<mnemonic, !listconcat(traits, [
+ Pure,
+ ViewLikeOpInterface,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>
+ ])>,
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef:$result)>{
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..e2ce7d93d227cc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -2079,6 +2080,95 @@ void ExpandShapeOp::getAsmResultNames(
setNameFn(getResult(), "expand_shape");
}
+LogicalResult ExpandShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ SmallVector<OpFoldResult> resultDims;
+ ArrayRef<int64_t> expandedShape = this->getResultType().getShape();
+ for (size_t expanded_dim = 0; expanded_dim < expandedShape.size();
+ ++expanded_dim) {
+ if (ShapedType::isDynamic(expandedShape[expanded_dim])) {
+ // Dynamic dimension case. Map expanded_dim to the corresponded
+ // collapsed dim. All other expanded dimensions corresponding to
+ // that collapsed dim must be static-size. Compute their product
+ // to divide the result size by.
+ auto reassoc = this->getReassociationIndices();
+ for (size_t collapsed_dim = 0; collapsed_dim < reassoc.size();
+ ++collapsed_dim) {
+ ReassociationIndices associated_dims = reassoc[collapsed_dim];
+ bool found_expanded_dim = false;
+ int64_t other_associated_dims_product_size = 1;
+ for (size_t associated_dim : associated_dims) {
+ if (associated_dim == expanded_dim) {
+ found_expanded_dim = true;
+ } else {
+ assert(!ShapedType::isDynamic(expandedShape[associated_dim]) &&
+ "At most one dimension of a reassociation group may be "
+ "dynamic in the result type.");
+ other_associated_dims_product_size *= expandedShape[associated_dim];
+ }
+ }
+ if (!found_expanded_dim) {
+ continue;
+ }
+ Value srcDimSize =
+ builder.create<memref::DimOp>(getLoc(), getSrc(), collapsed_dim);
+ Value resultDimSize = builder.create<arith::DivSIOp>(
+ getLoc(), srcDimSize,
+ builder.create<arith::ConstantIndexOp>(
+ getLoc(), other_associated_dims_product_size));
+ resultDims.push_back(resultDimSize);
+ }
+ } else {
+ resultDims.push_back(getAsIndexOpFoldResult(builder.getContext(),
+ expandedShape[expanded_dim]));
+ }
+ }
+ reifiedReturnShapes = {resultDims};
+ return success();
+}
+
+LogicalResult CollapseShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ SmallVector<OpFoldResult> resultDims;
+ ArrayRef<int64_t> collapsedShape = this->getResultType().getShape();
+ ArrayRef<int64_t> expandedShape = this->getSrcType().getShape();
+ for (size_t collapsed_dim = 0; collapsed_dim < collapsedShape.size();
+ ++collapsed_dim) {
+ if (ShapedType::isDynamic(collapsedShape[collapsed_dim])) {
+ // Dynamic dimension case. All other expanded dimensions corresponding
+ // to that collapsed_dim must be static-size. Compute their product
+ // to multiply the result size by.
+ auto reassoc = this->getReassociationIndices();
+ ReassociationIndices associated_dims = reassoc[collapsed_dim];
+ std::optional<size_t> expanded_dim;
+ int64_t other_associated_dims_product_size = 1;
+ for (size_t associated_dim : associated_dims) {
+ if (ShapedType::isDynamic(expandedShape[associated_dim])) {
+ assert(!expanded_dim && "At most one dimension of a reassociation "
+ "group may be dynamic in the result type.");
+ expanded_dim = associated_dim;
+ } else {
+ other_associated_dims_product_size *= expandedShape[associated_dim];
+ }
+ }
+ assert(expanded_dim && "No dynamic dimension in the reassociation group "
+ "to match the dynamic collapsed dimension.");
+ Value srcDimSize =
+ builder.create<memref::DimOp>(getLoc(), getSrc(), *expanded_dim);
+ Value resultDimSize = builder.create<arith::MulIOp>(
+ getLoc(), srcDimSize,
+ builder.create<arith::ConstantIndexOp>(
+ getLoc(), other_associated_dims_product_size));
+ resultDims.push_back(resultDimSize);
+ } else {
+ resultDims.push_back(getAsIndexOpFoldResult(
+ builder.getContext(), collapsedShape[collapsed_dim]));
+ }
+ }
+ reifiedReturnShapes = {resultDims};
+ return success();
+}
+
/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
/// result and operand. Layout maps are verified separately.
///
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 18e9a9d02e1081..fb0f9106e61bbf 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -25,3 +25,35 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
%0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
return %0 : index
}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 0
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+ -> index {
+ %c1 = arith.constant 1 : index
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
+ %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.collapse_shape)
+// CHECK-LABEL: func @dim_of_memref_collapse_shape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<1x?x2x4xi32>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 1
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<1x?x2x4xi32>
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_collapse_shape(%arg0: memref<1x?x2x4xi32>)
+ -> index {
+ %c0 = arith.constant 0 : index
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2, 3]]: memref<1x?x2x4xi32> into memref<?x8xi32>
+ %1 = memref.dim %0, %c0 : memref<?x8xi32>
+ return %1 : index
+}
Could you add a direct test for type inference?
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  SmallVector<OpFoldResult> resultDims;
  ArrayRef<int64_t> expandedShape = this->getResultType().getShape();
  for (size_t expanded_dim = 0; expanded_dim < expandedShape.size();
    } else {
      resultDims.push_back(getAsIndexOpFoldResult(builder.getContext(),
                                                  expandedShape[expanded_dim]));
    }
Looks like a good candidate for https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code.
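As an aside for readers, the early-continue restructuring suggested above can be modeled outside of MLIR. The following is a minimal Python sketch of the expand-shape reification loop (function and variable names are illustrative, not from the patch; known integer sizes stand in for the `memref.dim` SSA values the real code would create), showing how guard clauses flatten the nesting:

```python
DYN = -1  # stand-in for ShapedType::kDynamic

def reify_expand_shape(src_shape, reassoc, result_shape):
    """Compute expanded result sizes, early-continue style.

    src_shape: sizes of the collapsed source.
    reassoc: one list of result-dim indices per source dim.
    result_shape: expanded shape, with DYN marking a dynamic dim.
    """
    dims = []
    for i, size in enumerate(result_shape):
        if size != DYN:
            dims.append(size)  # static dim: nothing to compute
            continue
        # Dynamic dim: find its reassociation group, then divide the
        # source dim size by the product of the static sibling sizes.
        for group_idx, group in enumerate(reassoc):
            if i not in group:
                continue
            sibling_product = 1
            for d in group:
                if d != i:
                    assert result_shape[d] != DYN, "at most one dynamic dim per group"
                    sibling_product *= result_shape[d]
            dims.append(src_shape[group_idx] // sibling_product)
    return dims
```

With the shapes from the new test (`memref<?x8xi32> into memref<1x?x2x4xi32>`, taking the dynamic source dim to be 16 at runtime), the dynamic result dim comes out as 16 // 1 = 16.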
  SmallVector<OpFoldResult> resultDims;
  ArrayRef<int64_t> expandedShape = this->getResultType().getShape();
  for (size_t expanded_dim = 0; expanded_dim < expandedShape.size();
       ++expanded_dim) {
MLIR uses camelCase for variables, here and below.
      // collapsed dim. All other expanded dimensions corresponding to
      // that collapsed dim must be static-size. Compute their product
      // to divide the result size by.
      auto reassoc = this->getReassociationIndices();
Please expand auto unless the type is obvious from line-level context.
      }
      Value srcDimSize =
          builder.create<memref::DimOp>(getLoc(), getSrc(), collapsed_dim);
      Value resultDimSize = builder.create<arith::DivSIOp>(
I'd rather use an unsigned division here; it lowers to simpler code on most targets, and sizes are unsigned.
Thanks for the drive-by review. Actually I am rebasing on top of #69267, which will make this PR trivial as now the output shape is encoded on the op already. The code simplifies to:

LogicalResult ExpandShapeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
  return success();
}
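For readers unfamiliar with the helper used in that simplification: after the rebase the op carries its output shape as a static attribute plus dynamic SSA operands, and getMixedValues zips them back together. A rough Python model of that pairing (names and the DYN sentinel are illustrative, not the MLIR API):

```python
DYN = -1  # stand-in for ShapedType::kDynamic

def mixed_values(static_sizes, dynamic_values):
    """Replace each DYN sentinel in static_sizes with the next dynamic
    value, mirroring how static and dynamic output-shape entries are
    merged into one mixed list."""
    it = iter(dynamic_values)
    return [next(it) if s == DYN else s for s in static_sizes]
```

For example, a static shape [1, DYN, 2, 4] with one dynamic SSA value yields a mixed list where only the second entry is a runtime value.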
In #88423 I came across a need for folding `memref.dim` into `memref.expand_shape` and it was (IMO rightly) suggested that the proper way to fix that was to implement `ReifyRankedShapedTypeOpInterface`. This PR does that for both `memref.{expand,collapse}_shape` to be consistent, as it would be surprising if these two ops weren't closely mirroring one another.

It would be good to carry on the `ReifyRankedShapedTypeOpInterface` migration, in particular completing it for `memref` and `tensor` ops, particularly to `tensor.{expand,collapse}_shape` to be consistent with this (and then one could drop some existing custom Fold patterns). However, I have heard of an ongoing project to generalize `expand_shape` to relax the requirement that at most one dimension in each reassociation group be dynamic. It would be wise to allow for that project to complete first, as the `ReifyRankedShapedTypeOpInterface` implementation will otherwise entrench the current semantics.

FYI @qedawkins @MaheshRavishankar @Shukla-Gaurav
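To make the collapse direction concrete as well: a collapsed dim is simply the product of the sizes in its reassociation group, which is why the dynamic case in the patch multiplies the one dynamic source dim by its static siblings. A minimal Python model (illustrative names, not MLIR code; known sizes stand in for runtime `memref.dim` values):

```python
def reify_collapse_shape(src_shape, reassoc):
    """Each collapsed dim is the product of the sizes in its
    reassociation group of the expanded source shape."""
    dims = []
    for group in reassoc:
        size = 1
        for d in group:
            size *= src_shape[d]
        dims.append(size)
    return dims
```

With the shapes from the new test (`memref<1x?x2x4xi32> into memref<?x8xi32>`, taking the dynamic source dim to be 16), the groups [[0, 1], [2, 3]] collapse to sizes 16 and 8.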