@@ -30,61 +30,26 @@ using namespace linalg;
30
30
// StructuredOp specific helpers.
31
31
// ===----------------------------------------------------------------------===//
32
32
33
- // / Relate the producer to the consumer loop iterations that access the same
34
- // / producer result element:
35
- // / consumerToProducerLoops =
36
- // / inverse(producerIndexingMap).compose(consumerIndexingMap).
37
- // / Return `consumerToProducerLoops` or none if the inversion fails.
38
- static Optional<AffineMap>
39
- getConsumerToProducerLoopsMap (AffineMap producerIndexingMap,
40
- AffineMap consumerIndexingMap) {
41
- assert (consumerIndexingMap.getNumResults () ==
42
- producerIndexingMap.getNumResults () &&
43
- " expect the number of indexing map results to match" );
44
- // Ensure the producer indexing map is a projected permutation.
45
- if (!producerIndexingMap.isProjectedPermutation ())
46
- return None;
47
- AffineMap inverseIndexingMap =
48
- inverseAndBroadcastProjectedPermuation (producerIndexingMap);
49
- return inverseIndexingMap.compose (consumerIndexingMap);
50
- }
51
-
52
- // / Returns the producer result slice dimensions tiled by the tile loop nest or
53
- // / an empty vector if `getConsumerToProducerLoopsMap` returns none.
54
- // TODO: replace by Fourier-Motzkin and/or compute starting from consumer.
55
- SmallVector<int64_t > getTiledSliceDims (OpResult producerResult,
56
- OpOperand *consumerOperand,
33
+ // / Returns the tiled slice dimensions given the tiled consumer loop dimensions.
34
+ // / The slice defines a hyper rectangular iteration space and fusing the
35
+ // / producer is always possible. However, depending on the consumer indexing
36
+ // / map, not all slice elements may be consumed and the tiles may overlap. In
37
+ // / these cases, fusion introduces redundant computation.
38
+ SmallVector<int64_t > getTiledSliceDims (OpOperand *consumerOperand,
57
39
ArrayRef<int64_t > tiledLoopDims) {
40
+ // Get the consumer operand indexing map.
58
41
LinalgOp consumerOp = consumerOperand->getOwner ();
59
- LinalgOp producerOp = producerResult.getOwner ();
60
- OpOperand *opOperand =
61
- producerOp.getOutputOperand (producerResult.getResultNumber ());
62
-
63
- // Compute the `consumerToProducerLoopsMap` and exit if the computation fails.
64
- AffineMap producerIndexingMap = producerOp.getTiedIndexingMap (opOperand);
65
- Optional<AffineMap> consumerToProducerLoopsMap =
66
- getConsumerToProducerLoopsMap (
67
- producerIndexingMap, consumerOp.getTiedIndexingMap (consumerOperand));
68
- if (!consumerToProducerLoopsMap.hasValue ())
69
- return {};
70
-
71
- // Compute the set of tiled producer loops.
72
- DenseSet<int64_t > tiledProducerLoops;
73
- for (auto en : enumerate(consumerToProducerLoopsMap->getResults ())) {
74
- for (int64_t dim : tiledLoopDims) {
75
- if (en.value ().isFunctionOfDim (dim))
76
- tiledProducerLoops.insert (en.index ());
42
+ AffineMap indexingMap = consumerOp.getTiedIndexingMap (consumerOperand);
43
+
44
+ // Search the slice dimensions tiled by a tile loop dimension.
45
+ DenseSet<int64_t > tiledSliceDims;
46
+ for (auto en : enumerate(indexingMap.getResults ())) {
47
+ for (auto tiledLoopDim : tiledLoopDims) {
48
+ if (en.value ().isFunctionOfDim (tiledLoopDim))
49
+ tiledSliceDims.insert (en.index ());
77
50
}
78
51
}
79
-
80
- // Compute the slice dimensions for the tiled producer loops.
81
- SmallVector<int64_t > tiledSliceDims;
82
- for (auto en : enumerate(producerIndexingMap.getResults ())) {
83
- auto dimExpr = en.value ().dyn_cast <AffineDimExpr>();
84
- if (dimExpr && tiledProducerLoops.count (dimExpr.getPosition ()) != 0 )
85
- tiledSliceDims.push_back (en.index ());
86
- }
87
- return tiledSliceDims;
52
+ return {tiledSliceDims.begin (), tiledSliceDims.end ()};
88
53
}
89
54
90
55
// / Returns the producer fused in place of `sliceOp`. Tile the producer operands
@@ -332,9 +297,10 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
332
297
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner ()))
333
298
return failure ();
334
299
335
- // Compute the slice dimensions tiled by `tileLoopNest`.
300
+ // Compute the tiled producer slice dimensions given the tiled root operation
301
+ // loop dimensions `loopDims`.
336
302
SmallVector<int64_t > tiledSliceDims =
337
- getTiledSliceDims (producerResult, rootOpOperand, loopDims);
303
+ getTiledSliceDims (rootOpOperand, loopDims);
338
304
if (tiledSliceDims.empty ())
339
305
return failure ();
340
306
0 commit comments