Skip to content

Let memref.{expand,collapse}_shape implement ReifyRankedShapedTypeOpInterface #89111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1546,8 +1546,11 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
//===----------------------------------------------------------------------===//

class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
MemRef_Op<mnemonic, !listconcat(traits,
[Pure, ViewLikeOpInterface])>,
MemRef_Op<mnemonic, !listconcat(traits, [
Pure,
ViewLikeOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>
])>,
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef:$result)>{

Expand Down
90 changes: 90 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -2079,6 +2080,95 @@ void ExpandShapeOp::getAsmResultNames(
setNameFn(getResult(), "expand_shape");
}

LogicalResult ExpandShapeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
SmallVector<OpFoldResult> resultDims;
ArrayRef<int64_t> expandedShape = this->getResultType().getShape();
for (size_t expanded_dim = 0; expanded_dim < expandedShape.size();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++expanded_dim) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLIR uses camelCase for variables, here and below.

if (ShapedType::isDynamic(expandedShape[expanded_dim])) {
// Dynamic dimension case. Map expanded_dim to the corresponded
// collapsed dim. All other expanded dimensions corresponding to
// that collapsed dim must be static-size. Compute their product
// to divide the result size by.
auto reassoc = this->getReassociationIndices();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please expand auto unless the type is obvious from line-level context.

for (size_t collapsed_dim = 0; collapsed_dim < reassoc.size();
++collapsed_dim) {
ReassociationIndices associated_dims = reassoc[collapsed_dim];
bool found_expanded_dim = false;
int64_t other_associated_dims_product_size = 1;
for (size_t associated_dim : associated_dims) {
if (associated_dim == expanded_dim) {
found_expanded_dim = true;
} else {
assert(!ShapedType::isDynamic(expandedShape[associated_dim]) &&
"At most one dimension of a reassociation group may be "
"dynamic in the result type.");
other_associated_dims_product_size *= expandedShape[associated_dim];
}
}
if (!found_expanded_dim) {
continue;
}
Value srcDimSize =
builder.create<memref::DimOp>(getLoc(), getSrc(), collapsed_dim);
Value resultDimSize = builder.create<arith::DivSIOp>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather use an unsigned division here, it lowers to simpler code on most targets, and sizes are unsigned.

getLoc(), srcDimSize,
builder.create<arith::ConstantIndexOp>(
getLoc(), other_associated_dims_product_size));
resultDims.push_back(resultDimSize);
}
} else {
resultDims.push_back(getAsIndexOpFoldResult(builder.getContext(),
expandedShape[expanded_dim]));
}
Comment on lines +2121 to +2124
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
reifiedReturnShapes = {resultDims};
return success();
}

LogicalResult CollapseShapeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
SmallVector<OpFoldResult> resultDims;
ArrayRef<int64_t> collapsedShape = this->getResultType().getShape();
ArrayRef<int64_t> expandedShape = this->getSrcType().getShape();
for (size_t collapsed_dim = 0; collapsed_dim < collapsedShape.size();
++collapsed_dim) {
if (ShapedType::isDynamic(collapsedShape[collapsed_dim])) {
// Dynamic dimension case. All other expanded dimensions corresponding
// to that collapsed_dim must be static-size. Compute their product
// to multiply the result size by.
auto reassoc = this->getReassociationIndices();
ReassociationIndices associated_dims = reassoc[collapsed_dim];
std::optional<size_t> expanded_dim;
int64_t other_associated_dims_product_size = 1;
for (size_t associated_dim : associated_dims) {
if (ShapedType::isDynamic(expandedShape[associated_dim])) {
assert(!expanded_dim && "At most one dimension of a reassociation "
"group may be dynamic in the result type.");
expanded_dim = associated_dim;
} else {
other_associated_dims_product_size *= expandedShape[associated_dim];
}
}
assert(expanded_dim && "No dynamic dimension in the reassociation group "
"to match the dynamic collapsed dimension.");
Value srcDimSize =
builder.create<memref::DimOp>(getLoc(), getSrc(), *expanded_dim);
Value resultDimSize = builder.create<arith::MulIOp>(
getLoc(), srcDimSize,
builder.create<arith::ConstantIndexOp>(
getLoc(), other_associated_dims_product_size));
resultDims.push_back(resultDimSize);
} else {
resultDims.push_back(getAsIndexOpFoldResult(
builder.getContext(), collapsedShape[collapsed_dim]));
}
}
reifiedReturnShapes = {resultDims};
return success();
}

/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
/// result and operand. Layout maps are verified separately.
///
Expand Down
32 changes: 32 additions & 0 deletions mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,35 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
%0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
return %0 : index
}

// -----

// Test case: Folding of memref.dim(memref.expand_shape)
// CHECK-LABEL: func @dim_of_memref_expand_shape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 0
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
// CHECK: return %[[DIM]] : index
func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
-> index {
%c1 = arith.constant 1 : index
%0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
%1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.collapse_shape)
// CHECK-LABEL: func @dim_of_memref_collapse_shape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<1x?x2x4xi32>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 1
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<1x?x2x4xi32>
// CHECK: return %[[DIM]] : index
func.func @dim_of_memref_collapse_shape(%arg0: memref<1x?x2x4xi32>)
-> index {
%c0 = arith.constant 0 : index
%0 = memref.collapse_shape %arg0 [[0, 1], [2, 3]]: memref<1x?x2x4xi32> into memref<?x8xi32>
%1 = memref.dim %0, %c0 : memref<?x8xi32>
return %1 : index
}