Skip to content

[MLIR][NFC] Return MemRefType in memref.subview return type inference functions #120024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2081,14 +2081,14 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
static Type inferResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
static Type inferResultType(MemRefType sourceMemRefType,
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);
static MemRefType inferResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
static MemRefType inferResultType(MemRefType sourceMemRefType,
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);

/// A rank-reducing result type can be inferred from the desired result
/// shape. Only the layout map is inferred.
Expand All @@ -2097,16 +2097,16 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
/// and the desired sizes. In case there are more "ones" among the sizes
/// than the difference in source/result rank, it is not clear which dims of
/// size one should be dropped.
static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceMemRefType,
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);
static MemRefType inferRankReducedResultType(
ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
static MemRefType inferRankReducedResultType(
ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);

/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
Expand Down
61 changes: 29 additions & 32 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2702,10 +2702,10 @@ void SubViewOp::getAsmResultNames(
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides) {
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides) {
unsigned rank = sourceMemRefType.getRank();
(void)rank;
assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
Expand Down Expand Up @@ -2744,10 +2744,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
sourceMemRefType.getMemorySpace());
}

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
Expand All @@ -2763,13 +2763,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
staticSizes, staticStrides);
}

Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceRankedTensorType,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
auto inferredType = llvm::cast<MemRefType>(
inferResultType(sourceRankedTensorType, offsets, sizes, strides));
MemRefType SubViewOp::inferRankReducedResultType(
ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
MemRefType inferredType =
inferResultType(sourceRankedTensorType, offsets, sizes, strides);
assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
"expected ");
if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
Expand All @@ -2795,11 +2794,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
inferredType.getMemorySpace());
}

Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceRankedTensorType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
MemRefType SubViewOp::inferRankReducedResultType(
ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
Expand All @@ -2826,8 +2824,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
sourceMemRefType, staticOffsets, staticSizes, staticStrides));
resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
staticSizes, staticStrides);
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
Expand Down Expand Up @@ -2992,8 +2990,8 @@ LogicalResult SubViewOp::verify() {

// Compute the expected result type, assuming that there are no rank
// reductions.
auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
MemRefType expectedType = SubViewOp::inferResultType(
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());

// Verify all properties of a shaped type: rank, element type and dimension
// sizes. This takes into account potential rank reductions.
Expand Down Expand Up @@ -3075,8 +3073,8 @@ static MemRefType getCanonicalSubViewResultType(
MemRefType currentResultType, MemRefType currentSourceType,
MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
sourceType, mixedOffsets, mixedSizes, mixedStrides));
MemRefType nonRankReducedType = SubViewOp::inferResultType(
sourceType, mixedOffsets, mixedSizes, mixedStrides);
FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
currentSourceType, currentResultType, mixedSizes);
if (failed(unusedDims))
Expand Down Expand Up @@ -3110,9 +3108,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
auto targetType =
llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
targetShape, memrefType, offsets, sizes, strides));
MemRefType targetType = SubViewOp::inferRankReducedResultType(
targetShape, memrefType, offsets, sizes, strides);
return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
sizes, strides);
}
Expand Down Expand Up @@ -3256,11 +3253,11 @@ struct SubViewReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
// Infer a memref type without taking into account any rank reductions.
auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
mixedSizes, mixedStrides);
MemRefType resTy = SubViewOp::inferResultType(
op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
if (!resTy)
return {};
MemRefType nonReducedType = cast<MemRefType>(resTy);
MemRefType nonReducedType = resTy;

// Directly return the non-rank reduced type if there are no dropped dims.
llvm::SmallBitVector droppedDims = op.getDroppedDims();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ propagateSubViewOp(RewriterBase &rewriter,
UnrealizedConversionCastOp conversionOp, SubViewOp op) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
MemRefType newResultType = SubViewOp::inferRankReducedResultType(
op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
op.getMixedSizes(), op.getMixedStrides()));
op.getMixedSizes(), op.getMixedStrides());
Value newSubview = rewriter.create<SubViewOp>(
op.getLoc(), newResultType, conversionOp.getOperand(0),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
Expand Down
12 changes: 5 additions & 7 deletions mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,13 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
// `subview(old_op)` is replaced by a new `subview(val)`.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(subviewUse);
Type newType = memref::SubViewOp::inferRankReducedResultType(
MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
Value newSubview = rewriter.create<memref::SubViewOp>(
subviewUse->getLoc(), cast<MemRefType>(newType), val,
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
subviewUse.getMixedStrides());
subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
subviewUse.getMixedSizes(), subviewUse.getMixedStrides());

// Ouch recursion ... is this really necessary?
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
Expand Down Expand Up @@ -211,9 +210,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
// Strides is [1, 1 ... 1 ].
auto dstMemref =
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
originalShape, mbMemRefType, offsets, sizes, strides));
MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
originalShape, mbMemRefType, offsets, sizes, strides);
Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,10 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
return memref::SubViewOp::inferRankReducedResultType(
extractSliceOp.getType().getShape(),
llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
mixedStrides));
mixedStrides);
}
};

Expand Down Expand Up @@ -692,10 +692,10 @@ struct InsertSliceOpInterface

// Take a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
auto subviewMemRefType =
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
MemRefType subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getShape(), dstMemrefType,
mixedOffsets, mixedSizes, mixedStrides));
mixedOffsets, mixedSizes, mixedStrides);
Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
mixedStrides);
Expand Down Expand Up @@ -960,12 +960,12 @@ struct ParallelInsertSliceOpInterface

// Take a subview of the destination buffer.
auto destBufferType = cast<MemRefType>(destBuffer->getType());
auto subviewMemRefType =
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
MemRefType subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
parallelInsertSliceOp.getMixedStrides()));
parallelInsertSliceOp.getMixedStrides());
Value subview = rewriter.create<memref::SubViewOp>(
parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ static MemRefType dropUnitDims(MemRefType inputType,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
auto targetShape = getReducedShape(sizes);
Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
targetShape, inputType, offsets, sizes, strides);
return cast<MemRefType>(rankReducedType).canonicalizeStridedLayout();
return rankReducedType.canonicalizeStridedLayout();
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
Expand Down
14 changes: 6 additions & 8 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,10 +1326,9 @@ class DropInnerMostUnitDimsTransferRead
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
auto resultMemrefType =
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
strides));
MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
strides);
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
Expand Down Expand Up @@ -1417,10 +1416,9 @@ class DropInnerMostUnitDimsTransferWrite
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
auto resultMemrefType =
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
strides));
MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
strides);
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));

Expand Down