-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Let memref.{expand,collapse}_shape
implement ReifyRankedShapedTypeOpInterface
#89111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Arith/Utils/Utils.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" | ||
#include "mlir/Dialect/Utils/StaticValueUtils.h" | ||
#include "mlir/IR/AffineMap.h" | ||
#include "mlir/IR/Builders.h" | ||
|
@@ -2079,6 +2080,95 @@ void ExpandShapeOp::getAsmResultNames( | |
setNameFn(getResult(), "expand_shape"); | ||
} | ||
|
||
LogicalResult ExpandShapeOp::reifyResultShapes( | ||
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { | ||
SmallVector<OpFoldResult> resultDims; | ||
ArrayRef<int64_t> expandedShape = this->getResultType().getShape(); | ||
for (size_t expanded_dim = 0; expanded_dim < expandedShape.size(); | ||
++expanded_dim) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MLIR uses camelCase for variables, here and below. |
||
if (ShapedType::isDynamic(expandedShape[expanded_dim])) { | ||
// Dynamic dimension case. Map expanded_dim to the corresponded | ||
// collapsed dim. All other expanded dimensions corresponding to | ||
// that collapsed dim must be static-size. Compute their product | ||
// to divide the result size by. | ||
auto reassoc = this->getReassociationIndices(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please expand |
||
for (size_t collapsed_dim = 0; collapsed_dim < reassoc.size(); | ||
++collapsed_dim) { | ||
ReassociationIndices associated_dims = reassoc[collapsed_dim]; | ||
bool found_expanded_dim = false; | ||
int64_t other_associated_dims_product_size = 1; | ||
for (size_t associated_dim : associated_dims) { | ||
if (associated_dim == expanded_dim) { | ||
found_expanded_dim = true; | ||
} else { | ||
assert(!ShapedType::isDynamic(expandedShape[associated_dim]) && | ||
"At most one dimension of a reassociation group may be " | ||
"dynamic in the result type."); | ||
other_associated_dims_product_size *= expandedShape[associated_dim]; | ||
} | ||
} | ||
if (!found_expanded_dim) { | ||
continue; | ||
} | ||
Value srcDimSize = | ||
builder.create<memref::DimOp>(getLoc(), getSrc(), collapsed_dim); | ||
Value resultDimSize = builder.create<arith::DivSIOp>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rather use an unsigned division here, it lowers to simpler code on most targets, and sizes are unsigned. |
||
getLoc(), srcDimSize, | ||
builder.create<arith::ConstantIndexOp>( | ||
getLoc(), other_associated_dims_product_size)); | ||
resultDims.push_back(resultDimSize); | ||
} | ||
} else { | ||
resultDims.push_back(getAsIndexOpFoldResult(builder.getContext(), | ||
expandedShape[expanded_dim])); | ||
} | ||
Comment on lines
+2121
to
+2124
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like a good candidate for https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code. |
||
} | ||
reifiedReturnShapes = {resultDims}; | ||
return success(); | ||
} | ||
|
||
LogicalResult CollapseShapeOp::reifyResultShapes( | ||
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { | ||
SmallVector<OpFoldResult> resultDims; | ||
ArrayRef<int64_t> collapsedShape = this->getResultType().getShape(); | ||
ArrayRef<int64_t> expandedShape = this->getSrcType().getShape(); | ||
for (size_t collapsed_dim = 0; collapsed_dim < collapsedShape.size(); | ||
++collapsed_dim) { | ||
if (ShapedType::isDynamic(collapsedShape[collapsed_dim])) { | ||
// Dynamic dimension case. All other expanded dimensions corresponding | ||
// to that collapsed_dim must be static-size. Compute their product | ||
// to multiply the result size by. | ||
auto reassoc = this->getReassociationIndices(); | ||
ReassociationIndices associated_dims = reassoc[collapsed_dim]; | ||
std::optional<size_t> expanded_dim; | ||
int64_t other_associated_dims_product_size = 1; | ||
for (size_t associated_dim : associated_dims) { | ||
if (ShapedType::isDynamic(expandedShape[associated_dim])) { | ||
assert(!expanded_dim && "At most one dimension of a reassociation " | ||
"group may be dynamic in the result type."); | ||
expanded_dim = associated_dim; | ||
} else { | ||
other_associated_dims_product_size *= expandedShape[associated_dim]; | ||
} | ||
} | ||
assert(expanded_dim && "No dynamic dimension in the reassociation group " | ||
"to match the dynamic collapsed dimension."); | ||
Value srcDimSize = | ||
builder.create<memref::DimOp>(getLoc(), getSrc(), *expanded_dim); | ||
Value resultDimSize = builder.create<arith::MulIOp>( | ||
getLoc(), srcDimSize, | ||
builder.create<arith::ConstantIndexOp>( | ||
getLoc(), other_associated_dims_product_size)); | ||
resultDims.push_back(resultDimSize); | ||
} else { | ||
resultDims.push_back(getAsIndexOpFoldResult( | ||
builder.getContext(), collapsedShape[collapsed_dim])); | ||
} | ||
} | ||
reifiedReturnShapes = {resultDims}; | ||
return success(); | ||
} | ||
|
||
/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp | ||
/// result and operand. Layout maps are verified separately. | ||
/// | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://llvm.org/docs/CodingStandards.html#don-t-evaluate-end-every-time-through-a-loop, here and below.