@@ -2697,10 +2697,10 @@ void SubViewOp::getAsmResultNames(
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<int64_t> staticOffsets,
-                                ArrayRef<int64_t> staticSizes,
-                                ArrayRef<int64_t> staticStrides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<int64_t> staticOffsets,
+                                      ArrayRef<int64_t> staticSizes,
+                                      ArrayRef<int64_t> staticStrides) {
   unsigned rank = sourceMemRefType.getRank();
   (void)rank;
   assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
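[Annotation, not part of the patch.] For context on the "special sentinels" the doc comment mentions: a minimal, hypothetical caller sketch (inferWithDynamicSize is an invented name) showing ShapedType::kDynamic marking a dynamic extent, and the new MemRefType return being used without a cast:

// Hypothetical illustration only -- not part of this commit.
MemRefType inferWithDynamicSize(MemRefType source) {
  SmallVector<int64_t> offsets(source.getRank(), 0);
  SmallVector<int64_t> strides(source.getRank(), 1);
  SmallVector<int64_t> sizes = llvm::to_vector(source.getShape());
  sizes[0] = ShapedType::kDynamic; // sentinel encoding the dynamic case
  // After this change the result is a MemRefType directly; no cast needed.
  return memref::SubViewOp::inferResultType(source, offsets, sizes, strides);
}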
@@ -2739,10 +2739,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                          sourceMemRefType.getMemorySpace());
 }
 
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<OpFoldResult> offsets,
-                                ArrayRef<OpFoldResult> sizes,
-                                ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<OpFoldResult> offsets,
+                                      ArrayRef<OpFoldResult> sizes,
+                                      ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2758,13 +2758,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                     staticSizes, staticStrides);
 }
 
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceRankedTensorType,
-                                           ArrayRef<int64_t> offsets,
-                                           ArrayRef<int64_t> sizes,
-                                           ArrayRef<int64_t> strides) {
-  auto inferredType = llvm::cast<MemRefType>(
-      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+MemRefType SubViewOp::inferRankReducedResultType(
+    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+    ArrayRef<int64_t> strides) {
+  MemRefType inferredType =
+      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
   assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
          "expected ");
   if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
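[Annotation, not part of the patch.] To make the rank-reducing overload concrete, a hedged sketch (dropLeadingUnitDim is a hypothetical helper, assuming a memref<1x16xf32> source): requesting resultShape {16} reduces the rank-2 inferred type to rank 1:

// Hypothetical illustration only -- not part of this commit.
MemRefType dropLeadingUnitDim(MemRefType source /* e.g. memref<1x16xf32> */) {
  SmallVector<int64_t> offsets = {0, 0};
  SmallVector<int64_t> sizes = {1, 16};
  SmallVector<int64_t> strides = {1, 1};
  // resultShape has rank 1, so the rank-2 inferred type is reduced to a
  // rank-1 memref with a compatible strided layout.
  return memref::SubViewOp::inferRankReducedResultType(
      /*resultShape=*/{16}, source, offsets, sizes, strides);
}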
@@ -2790,11 +2789,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                          inferredType.getMemorySpace());
 }
 
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceRankedTensorType,
-                                           ArrayRef<OpFoldResult> offsets,
-                                           ArrayRef<OpFoldResult> sizes,
-                                           ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferRankReducedResultType(
+    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2821,8 +2819,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
   auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
   // Structuring implementation this way avoids duplication between builders.
   if (!resultType) {
-    resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
-        sourceMemRefType, staticOffsets, staticSizes, staticStrides));
+    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
+                                            staticSizes, staticStrides);
   }
   result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2987,8 +2985,8 @@ LogicalResult SubViewOp::verify() {
 
   // Compute the expected result type, assuming that there are no rank
   // reductions.
-  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
-      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
+  MemRefType expectedType = SubViewOp::inferResultType(
+      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
 
   // Verify all properties of a shaped type: rank, element type and dimension
   // sizes. This takes into account potential rank reductions.
@@ -3070,8 +3068,8 @@ static MemRefType getCanonicalSubViewResultType(
     MemRefType currentResultType, MemRefType currentSourceType,
     MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
-  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
-      sourceType, mixedOffsets, mixedSizes, mixedStrides));
+  MemRefType nonRankReducedType = SubViewOp::inferResultType(
+      sourceType, mixedOffsets, mixedSizes, mixedStrides);
   FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
       currentSourceType, currentResultType, mixedSizes);
   if (failed(unusedDims))
@@ -3105,9 +3103,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
   SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
-  auto targetType =
-      llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
-          targetShape, memrefType, offsets, sizes, strides));
+  MemRefType targetType = SubViewOp::inferRankReducedResultType(
+      targetShape, memrefType, offsets, sizes, strides);
   return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                            sizes, strides);
 }
@@ -3251,11 +3248,11 @@ struct SubViewReturnTypeCanonicalizer {
                   ArrayRef<OpFoldResult> mixedSizes,
                   ArrayRef<OpFoldResult> mixedStrides) {
     // Infer a memref type without taking into account any rank reductions.
-    auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
-                                            mixedSizes, mixedStrides);
+    MemRefType resTy = SubViewOp::inferResultType(
+        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
     if (!resTy)
      return {};
-    MemRefType nonReducedType = cast<MemRefType>(resTy);
+    MemRefType nonReducedType = resTy;
 
     // Directly return the non-rank reduced type if there are no dropped dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
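[Annotation, not part of the patch.] The net effect at call sites, as the hunks above show: every llvm::cast<MemRefType> around the inference entry points disappears, because the functions now return the strongest type they can guarantee. Schematically (surrounding variables assumed):

// Before this change, callers had to cast the generic Type result:
//   auto ty = llvm::cast<MemRefType>(
//       SubViewOp::inferResultType(srcType, offsets, sizes, strides));
// After, the inferred type is already a MemRefType:
MemRefType ty = SubViewOp::inferResultType(srcType, offsets, sizes, strides);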