 using namespace mlir;
 using namespace mlir::memref;
 
-namespace {
-/// Idiomatic saturated operations on offsets, sizes and strides.
-namespace saturated_arith {
-struct Wrapper {
-  static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper size(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
-  bool operator==(Wrapper other) {
-    return (saturated && other.saturated) ||
-           (!saturated && !other.saturated && v == other.v);
-  }
-  bool operator!=(Wrapper other) { return !(*this == other); }
-  Wrapper operator+(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v + v};
-  }
-  Wrapper operator*(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v * v};
-  }
-  bool saturated;
-  int64_t v;
-};
-} // namespace saturated_arith
-} // namespace
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
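
Note: the deleted saturated_arith::Wrapper is superseded by the shared SaturatedInteger utility used throughout the hunks below. As a rough sketch, reconstructed from the deleted code and the new call sites (the actual upstream definition lives in MLIR's dialect utility headers and may differ in detail):

// Sketch of the assumed SaturatedInteger interface. A dynamic extent
// (ShapedType::kDynamic) is modeled as a "saturated" value that absorbs
// all arithmetic, exactly as the deleted Wrapper did.
struct SaturatedInteger {
  static SaturatedInteger wrap(int64_t v) {
    return ShapedType::isDynamic(v) ? SaturatedInteger{true, 0}
                                    : SaturatedInteger{false, v};
  }
  // A saturated value maps back to ShapedType::kDynamic.
  int64_t asInteger() const { return saturated ? ShapedType::kDynamic : v; }
  bool operator==(SaturatedInteger other) const {
    return (saturated && other.saturated) ||
           (!saturated && !other.saturated && v == other.v);
  }
  bool operator!=(SaturatedInteger other) const { return !(*this == other); }
  SaturatedInteger operator+(SaturatedInteger other) const {
    if (saturated || other.saturated)
      return SaturatedInteger{true, 0};
    return SaturatedInteger{false, other.v + v};
  }
  SaturatedInteger operator*(SaturatedInteger other) const {
    if (saturated || other.saturated)
      return SaturatedInteger{true, 0};
    return SaturatedInteger{false, other.v * v};
  }
  bool saturated;
  int64_t v;
};

One wrap/asInteger pair replaces the old offset/size/stride triplets, whose bodies were identical anyway.
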
@@ -2208,11 +2171,11 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
     ReassociationIndices reassoc = std::get<0>(it);
     int64_t currentStrideToExpand = std::get<1>(it);
     for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
-      using saturated_arith::Wrapper;
       reverseResultStrides.push_back(currentStrideToExpand);
-      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
-                               Wrapper::size(resultShape[shapeIndex--]))
-                                  .asStride();
+      currentStrideToExpand =
+          (SaturatedInteger::wrap(currentStrideToExpand) *
+           SaturatedInteger::wrap(resultShape[shapeIndex--]))
+              .asInteger();
     }
   }
   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
@@ -2332,10 +2295,9 @@ computeCollapsedLayoutMap(MemRefType srcType,
   unsigned resultStrideIndex = resultStrides.size() - 1;
   for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
-    using saturated_arith::Wrapper;
-    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
+    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
-      stride = stride * Wrapper::size(srcShape[idx]);
+      stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
       // Both source and result stride must have the same static value. In that
       // case, we can be sure, that the dimensions are collapsible (because they
@@ -2345,7 +2307,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
       // ops where obviously non-contiguous dims are collapsed, but accept ops
       // where we cannot be sure statically. Such ops may fail at runtime. See
       // the op documentation for details.
-      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
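
For intuition, the check above (together with the stride comparison that follows it in the full file) boils down to: the running product of inner sizes times the inner stride must match the next outer source stride. A hypothetical standalone instance, using the sketched interface from above:

// Is a 2x4 source with strides [4, 1] collapsible into one dimension?
// The inner stride (1) times the inner size (4) must equal the outer
// source stride (4). All values here are illustrative only.
auto stride = SaturatedInteger::wrap(/*inner stride=*/1) *
              SaturatedInteger::wrap(/*inner size=*/4);
auto srcStride = SaturatedInteger::wrap(/*outer stride=*/4);
bool collapsible = (stride == srcStride); // true; strides [8, 1] would fail
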
@@ -2371,11 +2333,11 @@ MemRefType CollapseShapeOp::computeCollapsedType(
   SmallVector<int64_t> resultShape;
   resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
-    using saturated_arith::Wrapper;
-    auto groupSize = Wrapper::size(1);
+    auto groupSize = SaturatedInteger::wrap(1);
     for (int64_t srcDim : group)
-      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
-    resultShape.push_back(groupSize.asSize());
+      groupSize =
+          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
+    resultShape.push_back(groupSize.asInteger());
   }
 
   if (srcType.getLayout().isIdentity()) {
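
To make the saturation semantics concrete, here is a hypothetical collapsed-group computation following the loop above: a single dynamic source dimension makes the entire collapsed size dynamic.

// Collapsing source sizes {4, ?, 8}: the product saturates on the
// dynamic dimension, so the collapsed size is ShapedType::kDynamic.
int64_t srcSizes[] = {4, ShapedType::kDynamic, 8};
auto groupSize = SaturatedInteger::wrap(1);
for (int64_t s : srcSizes)
  groupSize = groupSize * SaturatedInteger::wrap(s);
// groupSize.asInteger() == ShapedType::kDynamic
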
@@ -2586,11 +2548,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   int64_t targetOffset = sourceOffset;
   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetOffset =
-        (Wrapper::offset(targetOffset) +
-         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
-            .asOffset();
+    targetOffset = (SaturatedInteger::wrap(targetOffset) +
+                    SaturatedInteger::wrap(staticOffset) *
+                        SaturatedInteger::wrap(targetStride))
+                       .asInteger();
   }
 
   // Compute target stride whose value is:
@@ -2599,10 +2560,9 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   targetStrides.reserve(staticOffsets.size());
   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetStrides.push_back(
-        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
-            .asStride());
+    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
+                             SaturatedInteger::wrap(staticStride))
+                                .asInteger());
   }
 
   // The type is now known.
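
The two loops above compute targetOffset = sourceOffset + sum_i(staticOffset[i] * sourceStride[i]) and targetStrides[i] = sourceStride[i] * staticStride[i], with any dynamic operand saturating its result. A hypothetical instance with one dynamic offset:

// Offsets {2, ?} over source strides {8, 1}: the known term contributes
// 2 * 8, but the dynamic offset saturates the sum, so the resulting
// subview offset is dynamic. All values here are illustrative only.
int64_t targetOffset =
    (SaturatedInteger::wrap(/*sourceOffset=*/0) +
     SaturatedInteger::wrap(2) * SaturatedInteger::wrap(8) +
     SaturatedInteger::wrap(ShapedType::kDynamic) * SaturatedInteger::wrap(1))
        .asInteger();
// targetOffset == ShapedType::kDynamic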