[MLIR][NFC] Return MemRefType in memref.subview return type inference functions #120024
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir-vector

Author: Tomás Longeri (tlongeri)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/120024.diff

7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index a0d8d34f38237a..4e31bb153c5e7e 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2079,14 +2079,14 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
- static Type inferResultType(MemRefType sourceMemRefType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
- static Type inferResultType(MemRefType sourceMemRefType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ static MemRefType inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides);
+ static MemRefType inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<OpFoldResult> staticOffsets,
+ ArrayRef<OpFoldResult> staticSizes,
+ ArrayRef<OpFoldResult> staticStrides);
/// A rank-reducing result type can be inferred from the desired result
/// shape. Only the layout map is inferred.
@@ -2095,16 +2095,16 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
/// and the desired sizes. In case there are more "ones" among the sizes
/// than the difference in source/result rank, it is not clear which dims of
/// size one should be dropped.
- static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
- MemRefType sourceMemRefType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
- static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
- MemRefType sourceMemRefType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ static MemRefType inferRankReducedResultType(
+ ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides);
+ static MemRefType inferRankReducedResultType(
+ ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
+ ArrayRef<OpFoldResult> staticOffsets,
+ ArrayRef<OpFoldResult> staticSizes,
+ ArrayRef<OpFoldResult> staticStrides);
/// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 2219505c9b802f..12768f06fb1b0e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2697,10 +2697,10 @@ void SubViewOp::getAsmResultNames(
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides) {
unsigned rank = sourceMemRefType.getRank();
(void)rank;
assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
@@ -2739,10 +2739,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
sourceMemRefType.getMemorySpace());
}
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2758,13 +2758,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
staticSizes, staticStrides);
}
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
- MemRefType sourceRankedTensorType,
- ArrayRef<int64_t> offsets,
- ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides) {
- auto inferredType = llvm::cast<MemRefType>(
- inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+MemRefType SubViewOp::inferRankReducedResultType(
+ ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+ ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> strides) {
+ MemRefType inferredType =
+ inferResultType(sourceRankedTensorType, offsets, sizes, strides);
assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
"expected ");
if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
@@ -2790,11 +2789,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
inferredType.getMemorySpace());
}
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
- MemRefType sourceRankedTensorType,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferRankReducedResultType(
+ ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2821,8 +2819,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
- sourceMemRefType, staticOffsets, staticSizes, staticStrides));
+ resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
+ staticSizes, staticStrides);
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2987,8 +2985,8 @@ LogicalResult SubViewOp::verify() {
// Compute the expected result type, assuming that there are no rank
// reductions.
- auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
- baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
+ MemRefType expectedType = SubViewOp::inferResultType(
+ baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
// Verify all properties of a shaped type: rank, element type and dimension
// sizes. This takes into account potential rank reductions.
@@ -3070,8 +3068,8 @@ static MemRefType getCanonicalSubViewResultType(
MemRefType currentResultType, MemRefType currentSourceType,
MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
- auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
- sourceType, mixedOffsets, mixedSizes, mixedStrides));
+ MemRefType nonRankReducedType = SubViewOp::inferResultType(
+ sourceType, mixedOffsets, mixedSizes, mixedStrides);
FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
currentSourceType, currentResultType, mixedSizes);
if (failed(unusedDims))
@@ -3105,9 +3103,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
- auto targetType =
- llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
- targetShape, memrefType, offsets, sizes, strides));
+ MemRefType targetType = SubViewOp::inferRankReducedResultType(
+ targetShape, memrefType, offsets, sizes, strides);
return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
sizes, strides);
}
@@ -3251,11 +3248,11 @@ struct SubViewReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
// Infer a memref type without taking into account any rank reductions.
- auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
- mixedSizes, mixedStrides);
+ MemRefType resTy = SubViewOp::inferResultType(
+ op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
if (!resTy)
return {};
- MemRefType nonReducedType = cast<MemRefType>(resTy);
+ MemRefType nonReducedType = resTy;
// Directly return the non-rank reduced type if there are no dropped dims.
llvm::SmallBitVector droppedDims = op.getDroppedDims();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 1f06318cbd60e0..8ffea5a7839980 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -70,9 +70,9 @@ propagateSubViewOp(RewriterBase &rewriter,
UnrealizedConversionCastOp conversionOp, SubViewOp op) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
- auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+ MemRefType newResultType = SubViewOp::inferRankReducedResultType(
op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
- op.getMixedSizes(), op.getMixedStrides()));
+ op.getMixedSizes(), op.getMixedStrides());
Value newSubview = rewriter.create<SubViewOp>(
op.getLoc(), newResultType, conversionOp.getOperand(0),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index bc0dd034f63851..c475d92e0658e5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -60,14 +60,13 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
// `subview(old_op)` is replaced by a new `subview(val)`.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(subviewUse);
- Type newType = memref::SubViewOp::inferRankReducedResultType(
+ MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
Value newSubview = rewriter.create<memref::SubViewOp>(
- subviewUse->getLoc(), cast<MemRefType>(newType), val,
- subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
- subviewUse.getMixedStrides());
+ subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
+ subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
// Ouch recursion ... is this really necessary?
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
@@ -211,9 +210,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
// Strides is [1, 1 ... 1 ].
- auto dstMemref =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
- originalShape, mbMemRefType, offsets, sizes, strides));
+ MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
+ originalShape, mbMemRefType, offsets, sizes, strides);
Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9797b73f534a96..35862c74c57552 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -407,10 +407,10 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
- return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ return memref::SubViewOp::inferRankReducedResultType(
extractSliceOp.getType().getShape(),
llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
- mixedStrides));
+ mixedStrides);
}
};
@@ -692,10 +692,10 @@ struct InsertSliceOpInterface
// Take a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
- auto subviewMemRefType =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ MemRefType subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getShape(), dstMemrefType,
- mixedOffsets, mixedSizes, mixedStrides));
+ mixedOffsets, mixedSizes, mixedStrides);
Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
mixedStrides);
@@ -960,12 +960,12 @@ struct ParallelInsertSliceOpInterface
// Take a subview of the destination buffer.
auto destBufferType = cast<MemRefType>(destBuffer->getType());
- auto subviewMemRefType =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ MemRefType subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
- parallelInsertSliceOp.getMixedStrides()));
+ parallelInsertSliceOp.getMixedStrides());
Value subview = rewriter.create<memref::SubViewOp>(
parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index bd5f06a3b46d42..b124ea32af1b32 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -265,9 +265,9 @@ static MemRefType dropUnitDims(MemRefType inputType,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
auto targetShape = getReducedShape(sizes);
- Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
+ MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
targetShape, inputType, offsets, sizes, strides);
- return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
+ return canonicalizeStridedLayout(rankReducedType);
}
/// Creates a rank-reducing memref.subview op that drops unit dims from its
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 20cd9cba6909a6..3f3e9ae9df2865 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1319,10 +1319,9 @@ class DropInnerMostUnitDimsTransferRead
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
- auto resultMemrefType =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
- srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
- strides));
+ MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
+ srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+ strides);
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
@@ -1410,10 +1409,9 @@ class DropInnerMostUnitDimsTransferWrite
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
- auto resultMemrefType =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
- srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
- strides));
+ MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
+ srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+ strides);
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
|
Why do we need this patch? Can you also put the reason in the PR description?
Sure, but I'm not sure there's much to elaborate on; it's just nice to avoid casts and rely more on static type checks. Updated the description.
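For context, a minimal before/after sketch of the call-site pattern this change simplifies, based on the hunks above. The names `srcType`, `offsets`, `sizes`, and `strides` are placeholders for whatever values a caller has, and the snippet is illustrative rather than a standalone compilable unit:

```cpp
// Before this patch: inferResultType() returned an opaque mlir::Type,
// so every caller had to cast the result back down to MemRefType.
auto resultTypeBefore = llvm::cast<MemRefType>(
    memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides));

// With this patch: the function returns MemRefType directly, so the cast
// disappears and the type mismatch would be caught at compile time.
MemRefType resultTypeAfter =
    memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides);
```

Since `MemRefType` implicitly converts to `Type`, call sites that still want the type-erased value keep compiling unchanged.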
Thanks for the update! IMO, it is better than an empty description.
@tlongeri, do you need help landing this? If you don't have commit access, please just rebase and I can land it for you.
Force-pushed from 556c406 to 091e126.
… functions (llvm#120024)

Avoids the need for a cast, and matches the extra build functions, which take a `MemRefType`.