@@ -31,23 +31,17 @@ namespace {
 namespace saturated_arith {
 struct Wrapper {
   static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper size(int64_t v) {
     return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
-  int64_t asOffset() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
   int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
   bool operator==(Wrapper other) {
     return (saturated && other.saturated) ||
            (!saturated && !other.saturated && v == other.v);
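
For context, `Wrapper` implements saturating arithmetic over possibly-dynamic offsets, sizes, and strides: once a value is dynamic, every result derived from it stays dynamic. A minimal standalone sketch of the same idea, with MLIR's `ShapedType` stubbed out (the `kDynamic` sentinel and the saturating `operator*` below are illustrative assumptions, not the upstream definitions):

    #include <cstdint>
    #include <iostream>
    #include <limits>

    // Illustrative stand-in for ShapedType::kDynamic (not the real value).
    constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

    struct Wrapper {
      bool saturated;
      int64_t v;
      static Wrapper size(int64_t v) {
        return (v == kDynamic) ? Wrapper{true, 0} : Wrapper{false, v};
      }
      int64_t asSize() const { return saturated ? kDynamic : v; }
      // Assumed saturating multiply: dynamic times anything stays dynamic.
      Wrapper operator*(Wrapper other) const {
        if (saturated || other.saturated)
          return Wrapper{true, 0};
        return Wrapper{false, v * other.v};
      }
    };

    int main() {
      Wrapper a = Wrapper::size(4);
      Wrapper b = Wrapper::size(kDynamic);
      std::cout << (a * a).asSize() << "\n";               // 16
      std::cout << ((a * b).asSize() == kDynamic) << "\n"; // 1: still dynamic
    }
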
@@ -731,8 +725,7 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
     auto ss = std::get<0>(it), st = std::get<1>(it);
     if (ss != st)
-      if (ShapedType::isDynamic(ss) &&
-          !ShapedType::isDynamic(st))
+      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
         return false;
   }

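
The inner condition encodes when folding the cast into its consumer would be unsound: a dynamic source stride cannot fold into a static result stride, since that would fabricate information. A hedged standalone mirror of just that predicate (`kDynamic` below is an illustrative sentinel, not the real `ShapedType` constant):

    #include <cassert>
    #include <cstdint>
    #include <limits>

    // Illustrative sentinel standing in for ShapedType::kDynamic.
    constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
    static bool isDynamic(int64_t v) { return v == kDynamic; }

    // Mirror of the loop body above for one (source, result) stride pair.
    static bool strideFoldable(int64_t ss, int64_t st) {
      if (ss != st)
        if (isDynamic(ss) && !isDynamic(st))
          return false;
      return true;
    }

    int main() {
      assert(strideFoldable(4, 4));         // identical static strides: ok
      assert(strideFoldable(4, kDynamic));  // static -> dynamic drops info: ok
      assert(!strideFoldable(kDynamic, 4)); // dynamic -> static adds info: no
    }
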
@@ -765,8 +758,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   // same. They are also compatible if either one is dynamic (see
   // description of MemRefCastOp for details).
   auto checkCompatible = [](int64_t a, int64_t b) {
-    return (ShapedType::isDynamic(a) ||
-            ShapedType::isDynamic(b) || a == b);
+    return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
   };
   if (!checkCompatible(aOffset, bOffset))
     return false;
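
Unlike the fold check above, this compatibility relation is symmetric: either side being dynamic suffices. The same lambda, extracted into a standalone sanity check (again with an assumed sentinel for `ShapedType::kDynamic`):

    #include <cassert>
    #include <cstdint>
    #include <limits>

    // Illustrative sentinel standing in for ShapedType::kDynamic.
    constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
    static bool isDynamic(int64_t v) { return v == kDynamic; }

    int main() {
      auto checkCompatible = [](int64_t a, int64_t b) {
        return isDynamic(a) || isDynamic(b) || a == b;
      };
      assert(checkCompatible(8, 8));        // equal static values
      assert(checkCompatible(8, kDynamic)); // either side dynamic suffices
      assert(checkCompatible(kDynamic, 8)); // ... in both directions
      assert(!checkCompatible(8, 16));      // distinct static values
    }
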
@@ -1889,8 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
   if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) &&
-      resultOffset != expectedOffset)
+      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;

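
Concretely, this verifier check fires only when both offsets are static and disagree: for example, a `static_offsets` attribute beginning with 2 against a result type whose layout offset is 0 would produce "expected result type with offset = 2 instead of 0", while a dynamic offset on either side passes verification.
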
@@ -2944,18 +2935,6 @@ static MemRefType getCanonicalSubViewResultType(
       nonRankReducedType.getMemorySpace());
 }

-/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
-/// to deduce the result type. Additionally, reduce the rank of the inferred
-/// result type if `currentResultType` is lower rank than `sourceType`.
-static MemRefType getCanonicalSubViewResultType(
-    MemRefType currentResultType, MemRefType sourceType,
-    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
-    ArrayRef<OpFoldResult> mixedStrides) {
-  return getCanonicalSubViewResultType(currentResultType, sourceType,
-                                       sourceType, mixedOffsets, mixedSizes,
-                                       mixedStrides);
-}
-
 Value mlir::memref::createCanonicalRankReducingSubViewOp(
     OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
   auto memrefType = llvm::cast<MemRefType>(memref.getType());
@@ -3108,9 +3087,32 @@ struct SubViewReturnTypeCanonicalizer {
   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
-    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
-                                         mixedOffsets, mixedSizes,
-                                         mixedStrides);
+    // Infer a memref type without taking into account any rank reductions.
+    MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
+        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));
+
+    // Directly return the non-rank reduced type if there are no dropped dims.
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+    if (droppedDims.empty())
+      return nonReducedType;
+
+    // Take the strides and offset from the non-rank reduced type.
+    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
+
+    // Drop dims from shape and strides.
+    SmallVector<int64_t> targetShape;
+    SmallVector<int64_t> targetStrides;
+    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
+      if (droppedDims.test(i))
+        continue;
+      targetStrides.push_back(nonReducedStrides[i]);
+      targetShape.push_back(nonReducedType.getDimSize(i));
+    }
+
+    return MemRefType::get(targetShape, nonReducedType.getElementType(),
+                           StridedLayoutAttr::get(nonReducedType.getContext(),
+                                                  offset, targetStrides),
+                           nonReducedType.getMemorySpace());
   }
 };

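
To make the new drop-dims logic concrete, a standalone sketch on plain vectors (the example subview, the `droppedDims` encoding, and all names here are illustrative assumptions): for a subview of `memref<8x16x4xf32>` taking sizes `[1, 6, 3]`, the non-reduced shape `[1, 6, 3]` with strides `[64, 4, 1]` reduces to shape `[6, 3]` with strides `[4, 1]` once the unit dimension is dropped.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Non-rank-reduced result of a hypothetical subview of
      // memref<8x16x4xf32> with sizes [1, 6, 3]:
      // shape [1, 6, 3], strides [64, 4, 1].
      std::vector<int64_t> shape = {1, 6, 3};
      std::vector<int64_t> strides = {64, 4, 1};
      std::vector<bool> droppedDims = {true, false, false}; // unit dim 0

      std::vector<int64_t> targetShape, targetStrides;
      for (size_t i = 0; i < shape.size(); ++i) {
        if (droppedDims[i])
          continue; // skip dropped dims, as in the canonicalizer above
        targetShape.push_back(shape[i]);
        targetStrides.push_back(strides[i]);
      }

      // Prints "6 x stride 4" and "3 x stride 1", i.e. the rank-reduced
      // type memref<6x3xf32, strided<[4, 1], offset: ?>>.
      for (size_t i = 0; i < targetShape.size(); ++i)
        std::cout << targetShape[i] << " x stride " << targetStrides[i] << "\n";
    }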