Skip to content

Commit 556c406

Browse files
committed
[MLIR][NFC] Return MemRefType in memref.subview return type inference functions
1 parent 8345a95 commit 556c406

File tree

7 files changed

+70
-77
lines changed

7 files changed

+70
-77
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2079,14 +2079,14 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20792079
/// A subview result type can be fully inferred from the source type and the
20802080
/// static representation of offsets, sizes and strides. Special sentinels
20812081
/// encode the dynamic case.
2082-
static Type inferResultType(MemRefType sourceMemRefType,
2083-
ArrayRef<int64_t> staticOffsets,
2084-
ArrayRef<int64_t> staticSizes,
2085-
ArrayRef<int64_t> staticStrides);
2086-
static Type inferResultType(MemRefType sourceMemRefType,
2087-
ArrayRef<OpFoldResult> staticOffsets,
2088-
ArrayRef<OpFoldResult> staticSizes,
2089-
ArrayRef<OpFoldResult> staticStrides);
2082+
static MemRefType inferResultType(MemRefType sourceMemRefType,
2083+
ArrayRef<int64_t> staticOffsets,
2084+
ArrayRef<int64_t> staticSizes,
2085+
ArrayRef<int64_t> staticStrides);
2086+
static MemRefType inferResultType(MemRefType sourceMemRefType,
2087+
ArrayRef<OpFoldResult> staticOffsets,
2088+
ArrayRef<OpFoldResult> staticSizes,
2089+
ArrayRef<OpFoldResult> staticStrides);
20902090

20912091
/// A rank-reducing result type can be inferred from the desired result
20922092
/// shape. Only the layout map is inferred.
@@ -2095,16 +2095,16 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20952095
/// and the desired sizes. In case there are more "ones" among the sizes
20962096
/// than the difference in source/result rank, it is not clear which dims of
20972097
/// size one should be dropped.
2098-
static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2099-
MemRefType sourceMemRefType,
2100-
ArrayRef<int64_t> staticOffsets,
2101-
ArrayRef<int64_t> staticSizes,
2102-
ArrayRef<int64_t> staticStrides);
2103-
static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2104-
MemRefType sourceMemRefType,
2105-
ArrayRef<OpFoldResult> staticOffsets,
2106-
ArrayRef<OpFoldResult> staticSizes,
2107-
ArrayRef<OpFoldResult> staticStrides);
2098+
static MemRefType inferRankReducedResultType(
2099+
ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
2100+
ArrayRef<int64_t> staticOffsets,
2101+
ArrayRef<int64_t> staticSizes,
2102+
ArrayRef<int64_t> staticStrides);
2103+
static MemRefType inferRankReducedResultType(
2104+
ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
2105+
ArrayRef<OpFoldResult> staticOffsets,
2106+
ArrayRef<OpFoldResult> staticSizes,
2107+
ArrayRef<OpFoldResult> staticStrides);
21082108

21092109
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
21102110
/// and `static_strides` attributes.

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2697,10 +2697,10 @@ void SubViewOp::getAsmResultNames(
26972697
/// A subview result type can be fully inferred from the source type and the
26982698
/// static representation of offsets, sizes and strides. Special sentinels
26992699
/// encode the dynamic case.
2700-
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2701-
ArrayRef<int64_t> staticOffsets,
2702-
ArrayRef<int64_t> staticSizes,
2703-
ArrayRef<int64_t> staticStrides) {
2700+
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2701+
ArrayRef<int64_t> staticOffsets,
2702+
ArrayRef<int64_t> staticSizes,
2703+
ArrayRef<int64_t> staticStrides) {
27042704
unsigned rank = sourceMemRefType.getRank();
27052705
(void)rank;
27062706
assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
@@ -2739,10 +2739,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
27392739
sourceMemRefType.getMemorySpace());
27402740
}
27412741

2742-
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2743-
ArrayRef<OpFoldResult> offsets,
2744-
ArrayRef<OpFoldResult> sizes,
2745-
ArrayRef<OpFoldResult> strides) {
2742+
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2743+
ArrayRef<OpFoldResult> offsets,
2744+
ArrayRef<OpFoldResult> sizes,
2745+
ArrayRef<OpFoldResult> strides) {
27462746
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
27472747
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
27482748
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2758,13 +2758,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
27582758
staticSizes, staticStrides);
27592759
}
27602760

2761-
Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2762-
MemRefType sourceRankedTensorType,
2763-
ArrayRef<int64_t> offsets,
2764-
ArrayRef<int64_t> sizes,
2765-
ArrayRef<int64_t> strides) {
2766-
auto inferredType = llvm::cast<MemRefType>(
2767-
inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2761+
MemRefType SubViewOp::inferRankReducedResultType(
2762+
ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2763+
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2764+
ArrayRef<int64_t> strides) {
2765+
MemRefType inferredType =
2766+
inferResultType(sourceRankedTensorType, offsets, sizes, strides);
27682767
assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
27692768
"expected ");
27702769
if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
@@ -2790,11 +2789,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
27902789
inferredType.getMemorySpace());
27912790
}
27922791

2793-
Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2794-
MemRefType sourceRankedTensorType,
2795-
ArrayRef<OpFoldResult> offsets,
2796-
ArrayRef<OpFoldResult> sizes,
2797-
ArrayRef<OpFoldResult> strides) {
2792+
MemRefType SubViewOp::inferRankReducedResultType(
2793+
ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2794+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2795+
ArrayRef<OpFoldResult> strides) {
27982796
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
27992797
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
28002798
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2821,8 +2819,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
28212819
auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
28222820
// Structuring implementation this way avoids duplication between builders.
28232821
if (!resultType) {
2824-
resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2825-
sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2822+
resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2823+
staticSizes, staticStrides);
28262824
}
28272825
result.addAttributes(attrs);
28282826
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2987,8 +2985,8 @@ LogicalResult SubViewOp::verify() {
29872985

29882986
// Compute the expected result type, assuming that there are no rank
29892987
// reductions.
2990-
auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2991-
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2988+
MemRefType expectedType = SubViewOp::inferResultType(
2989+
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
29922990

29932991
// Verify all properties of a shaped type: rank, element type and dimension
29942992
// sizes. This takes into account potential rank reductions.
@@ -3070,8 +3068,8 @@ static MemRefType getCanonicalSubViewResultType(
30703068
MemRefType currentResultType, MemRefType currentSourceType,
30713069
MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
30723070
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3073-
auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
3074-
sourceType, mixedOffsets, mixedSizes, mixedStrides));
3071+
MemRefType nonRankReducedType = SubViewOp::inferResultType(
3072+
sourceType, mixedOffsets, mixedSizes, mixedStrides);
30753073
FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
30763074
currentSourceType, currentResultType, mixedSizes);
30773075
if (failed(unusedDims))
@@ -3105,9 +3103,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
31053103
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
31063104
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
31073105
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3108-
auto targetType =
3109-
llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
3110-
targetShape, memrefType, offsets, sizes, strides));
3106+
MemRefType targetType = SubViewOp::inferRankReducedResultType(
3107+
targetShape, memrefType, offsets, sizes, strides);
31113108
return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
31123109
sizes, strides);
31133110
}
@@ -3251,11 +3248,11 @@ struct SubViewReturnTypeCanonicalizer {
32513248
ArrayRef<OpFoldResult> mixedSizes,
32523249
ArrayRef<OpFoldResult> mixedStrides) {
32533250
// Infer a memref type without taking into account any rank reductions.
3254-
auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3255-
mixedSizes, mixedStrides);
3251+
MemRefType resTy = SubViewOp::inferResultType(
3252+
op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
32563253
if (!resTy)
32573254
return {};
3258-
MemRefType nonReducedType = cast<MemRefType>(resTy);
3255+
MemRefType nonReducedType = resTy;
32593256

32603257
// Directly return the non-rank reduced type if there are no dropped dims.
32613258
llvm::SmallBitVector droppedDims = op.getDroppedDims();

mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ propagateSubViewOp(RewriterBase &rewriter,
7070
UnrealizedConversionCastOp conversionOp, SubViewOp op) {
7171
OpBuilder::InsertionGuard g(rewriter);
7272
rewriter.setInsertionPoint(op);
73-
auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
73+
MemRefType newResultType = SubViewOp::inferRankReducedResultType(
7474
op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
75-
op.getMixedSizes(), op.getMixedStrides()));
75+
op.getMixedSizes(), op.getMixedStrides());
7676
Value newSubview = rewriter.create<SubViewOp>(
7777
op.getLoc(), newResultType, conversionOp.getOperand(0),
7878
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());

mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,13 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
6060
// `subview(old_op)` is replaced by a new `subview(val)`.
6161
OpBuilder::InsertionGuard g(rewriter);
6262
rewriter.setInsertionPoint(subviewUse);
63-
Type newType = memref::SubViewOp::inferRankReducedResultType(
63+
MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
6464
subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
6565
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
6666
subviewUse.getStaticStrides());
6767
Value newSubview = rewriter.create<memref::SubViewOp>(
68-
subviewUse->getLoc(), cast<MemRefType>(newType), val,
69-
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
70-
subviewUse.getMixedStrides());
68+
subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
69+
subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
7170

7271
// Ouch recursion ... is this really necessary?
7372
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
@@ -211,9 +210,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
211210
for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
212211
sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
213212
// Strides is [1, 1 ... 1 ].
214-
auto dstMemref =
215-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
216-
originalShape, mbMemRefType, offsets, sizes, strides));
213+
MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
214+
originalShape, mbMemRefType, offsets, sizes, strides);
217215
Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
218216
offsets, sizes, strides);
219217
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,10 @@ struct ExtractSliceOpInterface
407407
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
408408
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
409409
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
410-
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
410+
return memref::SubViewOp::inferRankReducedResultType(
411411
extractSliceOp.getType().getShape(),
412412
llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
413-
mixedStrides));
413+
mixedStrides);
414414
}
415415
};
416416

@@ -692,10 +692,10 @@ struct InsertSliceOpInterface
692692

693693
// Take a subview of the destination buffer.
694694
auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
695-
auto subviewMemRefType =
696-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
695+
MemRefType subviewMemRefType =
696+
memref::SubViewOp::inferRankReducedResultType(
697697
insertSliceOp.getSourceType().getShape(), dstMemrefType,
698-
mixedOffsets, mixedSizes, mixedStrides));
698+
mixedOffsets, mixedSizes, mixedStrides);
699699
Value subView = rewriter.create<memref::SubViewOp>(
700700
loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
701701
mixedStrides);
@@ -960,12 +960,12 @@ struct ParallelInsertSliceOpInterface
960960

961961
// Take a subview of the destination buffer.
962962
auto destBufferType = cast<MemRefType>(destBuffer->getType());
963-
auto subviewMemRefType =
964-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
963+
MemRefType subviewMemRefType =
964+
memref::SubViewOp::inferRankReducedResultType(
965965
parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
966966
parallelInsertSliceOp.getMixedOffsets(),
967967
parallelInsertSliceOp.getMixedSizes(),
968-
parallelInsertSliceOp.getMixedStrides()));
968+
parallelInsertSliceOp.getMixedStrides());
969969
Value subview = rewriter.create<memref::SubViewOp>(
970970
parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
971971
parallelInsertSliceOp.getMixedOffsets(),

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ static MemRefType dropUnitDims(MemRefType inputType,
265265
ArrayRef<OpFoldResult> sizes,
266266
ArrayRef<OpFoldResult> strides) {
267267
auto targetShape = getReducedShape(sizes);
268-
Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
268+
MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
269269
targetShape, inputType, offsets, sizes, strides);
270-
return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
270+
return canonicalizeStridedLayout(rankReducedType);
271271
}
272272

273273
/// Creates a rank-reducing memref.subview op that drops unit dims from its

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,10 +1319,9 @@ class DropInnerMostUnitDimsTransferRead
13191319
rewriter.getIndexAttr(0));
13201320
SmallVector<OpFoldResult> strides(srcType.getRank(),
13211321
rewriter.getIndexAttr(1));
1322-
auto resultMemrefType =
1323-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1324-
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1325-
strides));
1322+
MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1323+
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1324+
strides);
13261325
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
13271326
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
13281327
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
@@ -1410,10 +1409,9 @@ class DropInnerMostUnitDimsTransferWrite
14101409
rewriter.getIndexAttr(0));
14111410
SmallVector<OpFoldResult> strides(srcType.getRank(),
14121411
rewriter.getIndexAttr(1));
1413-
auto resultMemrefType =
1414-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1415-
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1416-
strides));
1412+
MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1413+
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1414+
strides);
14171415
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
14181416
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
14191417

0 commit comments

Comments
 (0)