18
18
namespace mlir {
19
19
namespace linalg {
20
20
21
+ struct LinalgFusionOptions ;
21
22
struct LinalgTilingOptions ;
22
23
23
24
// ===----------------------------------------------------------------------===//
@@ -30,6 +31,14 @@ struct TiledLinalgOp {
30
31
SmallVector<Operation *, 8 > loops;
31
32
};
32
33
34
+ struct TiledAndFusedLinalgOps {
35
+ LinalgOp op;
36
+ SmallVector<LinalgOp, 1 > fusedProducers;
37
+ SmallVector<LinalgOp, 1 > originalProducers;
38
+ SmallVector<Operation *, 4 > fusedLoops;
39
+ SmallVector<Operation *, 4 > unfusedLoops;
40
+ };
41
+
33
42
// / Populates patterns for vectorization of all ConvN-D ops.
34
43
void populateConvVectorizationPatterns (
35
44
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -53,6 +62,71 @@ void populateConvVectorizationPatterns(
53
62
Optional<TiledLinalgOp> tileLinalgOp (OpBuilder &b, LinalgOp op,
54
63
const LinalgTilingOptions &options);
55
64
65
+ // / Tile and fuse the `op` with its producers. The tile and fuse proceeds in
66
+ // / three steps:
67
+ // / - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
68
+ // / + fuse loops).
69
+ // / - Tile just these loops of the consumer (root operation) and fuse with
70
+ // / the producer.
71
+ // / - Tile again the tiled consumer operation produced above to do rest of
72
+ // / the tiling specified by the `tilingOptions`.
73
+ // /
74
+ // / For example, consider the sequence of matmul below
75
+ // /
76
+ // / linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
77
+ // / outs(%arg2 : memref<256x32xf32>)
78
+ // / linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
79
+ // / outs(%arg4 : memref<256x32xf32>)
80
+ // /
81
+ // / It is legal to fuse the RAW dependence (through %arg2) by only fusing the
82
+ // / matmuls row-wise. For example, the fused computation for the above is shown
83
+ // / below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
84
+ // / along the rows of the matrix. The entire rows of the first matmul operation
85
+ // / need to be computed before they can be used for the second matmul. The
86
+ // / second matmul is further tiled (similar to normal tiling).
87
+ // /
88
+ // / #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
89
+ // / #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
90
+ // / scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
91
+ // / %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
92
+ // / : memref<256x32xf32> to memref<16x32xf32, #map0>
93
+ // / %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
94
+ // / : memref<256x32xf32> to memref<16x32xf32, #map0>
95
+ // / %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
96
+ // / : memref<256x32xf32> to memref<16x32xf32, #map0>
97
+ // / %3 = subview %arg1[0, 0] [32, 32] [1, 1]
98
+ // / : memref<32x32xf32> to memref<32x32xf32, #map1>
99
+ // / linalg.matmul
100
+ // / ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
101
+ // / outs(%0 : memref<16x32xf32, #map0>)
102
+ // / scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
103
+ // / scf.for %arg7 = %c0 to %c32 step %c4 {
104
+ // / %4 = subview %0[0, %arg7] [16, 4] [1, 1]
105
+ // / : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
106
+ // / %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
107
+ // / : memref<32x32xf32> to memref<4x8xf32, #map0>
108
+ // / %6 = subview %1[0, %arg6] [16, 8] [1, 1]
109
+ // / : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
110
+ // / linalg.matmul
111
+ // / ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
112
+ // / outs(%6 : memref<16x8xf32, #map0>)
113
+ // / }
114
+ // / scf.yield
115
+ // / }
116
+ // / scf.yield
117
+ // / }
118
+ // /
119
+ // / The following tiling options are handled differently in tile+fuse (compared
120
+ // / to tile only)
121
+ // / - Interchange of the tiling loops is not supported right now.
122
+ // / - Distribution is only done for the tile+fuse loops. The tiled loops
123
+ // / generated by the second tiling are not distributed.
124
+ Optional<TiledAndFusedLinalgOps>
125
+ tileAndFuseLinalgOps (PatternRewriter &rewriter, LinalgOp op,
126
+ const LinalgDependenceGraph &dependenceGraph,
127
+ const LinalgTilingOptions &tilingOptions,
128
+ const LinalgFusionOptions &fusionOptions);
129
+
56
130
// / Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
57
131
// / This is an in-place transformation controlled by `interchangeVector`.
58
132
// / An empty vector is interpreted as the identity permutation and the
@@ -323,6 +397,63 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
323
397
}
324
398
};
325
399
400
+ struct LinalgFusionOptions {
401
+ // / Optional list of operands indices to use for fusion. When unspecified,
402
+ // / only one fusion is done, i.e., the pattern returns after the first fusion.
403
+ Optional<DenseSet<unsigned >> indicesToFuse = None;
404
+ LinalgFusionOptions &setIndicesToFuse (ArrayRef<int64_t > operands) {
405
+ indicesToFuse = DenseSet<unsigned >();
406
+ indicesToFuse->insert (operands.begin (), operands.end ());
407
+ return *this ;
408
+ }
409
+ };
410
+
411
+ struct LinalgBaseTileAndFusePattern : public RewritePattern {
412
+ LinalgBaseTileAndFusePattern (StringRef opName, MLIRContext *context,
413
+ const LinalgDependenceGraph &dependenceGraph,
414
+ LinalgTilingOptions tilingOptions,
415
+ LinalgFusionOptions fusionOptions,
416
+ LinalgMarker marker = LinalgMarker(),
417
+ LinalgMarker fusedOpMarker = LinalgMarker(),
418
+ LinalgMarker originalOpMarker = LinalgMarker(),
419
+ PatternBenefit benefit = 1 );
420
+ LogicalResult matchAndRewrite (Operation *op,
421
+ PatternRewriter &rewriter) const override ;
422
+
423
+ private:
424
+ // / Dependence graph needed for fusion.
425
+ const LinalgDependenceGraph &dependenceGraph;
426
+ // / Options to control tiling.
427
+ LinalgTilingOptions tilingOptions;
428
+ // / Options to control fusion.
429
+ LinalgFusionOptions fusionOptions;
430
+ // / Marker to control application of the pattern.
431
+ LinalgMarker marker;
432
+ // / Marker set on the fused op after tile and fuse.
433
+ LinalgMarker fusedOpMarker;
434
+ // / The dependenceGraph is not modifiable, i.e. if the Linalg operations used
435
+ // / to build the dependence graph changes then the dependenceGraph needs to be
436
+ // / recomputed right now. To not invalidate the dependenceGraph as
437
+ // / transformation happens, the original producer can be tagged with a marker
438
+ // / that can be later used to delete the original operations.
439
+ LinalgMarker originalOpMarker;
440
+ };
441
+
442
+ template <typename OpTy>
443
+ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
444
+ LinalgTileAndFusePattern (MLIRContext *context,
445
+ const LinalgDependenceGraph &dependenceGraph,
446
+ LinalgTilingOptions tilingOptions,
447
+ LinalgFusionOptions fusionOptions,
448
+ LinalgMarker marker = LinalgMarker(),
449
+ LinalgMarker fusedOpMarker = LinalgMarker(),
450
+ LinalgMarker originalOpMarker = LinalgMarker(),
451
+ PatternBenefit benefit = 1 )
452
+ : LinalgBaseTileAndFusePattern(
453
+ OpTy::getOperationName (), context, dependenceGraph, tilingOptions,
454
+ fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
455
+ };
456
+
326
457
// /
327
458
// / Linalg interchange patterns.
328
459
// /
0 commit comments