Skip to content

Commit c694588

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Add pattern to tile and fuse Linalg operations on buffers.
The pattern is structured similar to other patterns like LinalgTilingPattern. The fusion patterns takes options that allows you to fuse with producers of multiple operands at once. - The pattern fuses only at the level that is known to be legal, i.e if a reduction loop in the consumer is tiled, then fusion should happen "before" this loop. Some refactoring of the fusion code is needed to fuse only where it is legal. - Since the fusion on buffers uses the LinalgDependenceGraph that is not mutable in place the fusion pattern keeps the original operations in the IR, but are tagged with a marker that can be later used to find the original operations. This change also fixes an issue with tiling and distribution/interchange where if the tile size of a loop were 0 it wasnt account for in these. Differential Revision: https://reviews.llvm.org/D88435
1 parent 8d250ac commit c694588

File tree

10 files changed

+1033
-106
lines changed

10 files changed

+1033
-106
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,24 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
459459
}));
460460
}]
461461
>,
462+
InterfaceMethod<
463+
/*desc=*/[{
464+
Return the position of buffer in inputs + outputs list
465+
}],
466+
/*retTy=*/"Optional<unsigned>",
467+
/*methodName=*/"getIndexOfInputAndOutputBuffer",
468+
/*args=*/(ins "Value":$value),
469+
/*methodBody=*/"",
470+
/*defaultImplementation=*/[{
471+
Optional<unsigned> inputIndex = getIndexOfInput(value);
472+
if (inputIndex.hasValue()) return inputIndex.getValue();
473+
Optional<unsigned> outputIndex = getIndexOfOutputBuffer(value);
474+
if (outputIndex.hasValue()) {
475+
return $_op.getNumInputs() + outputIndex.getValue();
476+
}
477+
return llvm::None;
478+
}]
479+
>,
462480

463481
//===------------------------------------------------------------------===//
464482
// Other interface methods.

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
namespace mlir {
1919
namespace linalg {
2020

21+
struct LinalgFusionOptions;
2122
struct LinalgTilingOptions;
2223

2324
//===----------------------------------------------------------------------===//
@@ -30,6 +31,14 @@ struct TiledLinalgOp {
3031
SmallVector<Operation *, 8> loops;
3132
};
3233

34+
struct TiledAndFusedLinalgOps {
35+
LinalgOp op;
36+
SmallVector<LinalgOp, 1> fusedProducers;
37+
SmallVector<LinalgOp, 1> originalProducers;
38+
SmallVector<Operation *, 4> fusedLoops;
39+
SmallVector<Operation *, 4> unfusedLoops;
40+
};
41+
3342
/// Populates patterns for vectorization of all ConvN-D ops.
3443
void populateConvVectorizationPatterns(
3544
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -53,6 +62,71 @@ void populateConvVectorizationPatterns(
5362
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
5463
const LinalgTilingOptions &options);
5564

65+
/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in
66+
/// three steps
67+
/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
68+
/// + fuse loops).
69+
/// - Tile just these loops of the consumer (root operation) and fuse with
70+
/// the producer.
71+
/// - Tile again the tiled consumer operation produced above to do rest of
72+
/// the tiling specified by the `tilingOptions`.
73+
///
74+
/// For example, consider the sequence of matmul below
75+
///
76+
/// linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
77+
/// outs(%arg2 : memref<256x32xf32>)
78+
/// linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
79+
/// outs(%arg4 : memref<256x32xf32>)
80+
///
81+
/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the
82+
/// matmuls row-wise. For example, the fused computation for the above is shown
83+
/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
84+
/// along the rows of the matrix. The entire rows of the first matmul operation
85+
/// need to be computed before they can be used for the second matmul. The
86+
/// second matmul is further tiled (similar to normal tiling).
87+
///
88+
/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
89+
/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
90+
/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
91+
/// %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
92+
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
93+
/// %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
94+
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
95+
/// %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
96+
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
97+
/// %3 = subview %arg1[0, 0] [32, 32] [1, 1]
98+
/// : memref<32x32xf32> to memref<32x32xf32, #map1>
99+
/// linalg.matmul
100+
/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
101+
/// outs(%0 : memref<16x32xf32, #map0>)
102+
/// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
103+
/// scf.for %arg7 = %c0 to %c32 step %c4 {
104+
/// %4 = subview %0[0, %arg7] [16, 4] [1, 1]
105+
/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
106+
/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
107+
/// : memref<32x32xf32> to memref<4x8xf32, #map0>
108+
/// %6 = subview %1[0, %arg6] [16, 8] [1, 1]
109+
/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
110+
/// linalg.matmul
111+
/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
112+
/// outs(%6 : memref<16x8xf32, #map0>)
113+
/// }
114+
/// scf.yield
115+
/// }
116+
/// scf.yield
117+
/// }
118+
///
119+
/// The following tiling options are handled differently in tile+fuse (compared
120+
/// to tile only)
121+
/// - Interchange of the tiling loops is not supported right now.
122+
/// - Distribution is only done for the tile+fuse loops. The tiled loops
123+
/// generated by the second tiling is not distributed.
124+
Optional<TiledAndFusedLinalgOps>
125+
tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
126+
const LinalgDependenceGraph &dependenceGraph,
127+
const LinalgTilingOptions &tilingOptions,
128+
const LinalgFusionOptions &fusionOptions);
129+
56130
/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
57131
/// This is an in-place transformation controlled by `interchangeVector`.
58132
/// An empty vector is interpreted as the identity permutation and the
@@ -323,6 +397,63 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
323397
}
324398
};
325399

400+
struct LinalgFusionOptions {
401+
/// Optional list of operands indices to use for fusion. When unspecified,
402+
/// only one fusion is done, i.e., the pattern returns after the first fusion.
403+
Optional<DenseSet<unsigned>> indicesToFuse = None;
404+
LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
405+
indicesToFuse = DenseSet<unsigned>();
406+
indicesToFuse->insert(operands.begin(), operands.end());
407+
return *this;
408+
}
409+
};
410+
411+
struct LinalgBaseTileAndFusePattern : public RewritePattern {
412+
LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context,
413+
const LinalgDependenceGraph &dependenceGraph,
414+
LinalgTilingOptions tilingOptions,
415+
LinalgFusionOptions fusionOptions,
416+
LinalgMarker marker = LinalgMarker(),
417+
LinalgMarker fusedOpMarker = LinalgMarker(),
418+
LinalgMarker originalOpMarker = LinalgMarker(),
419+
PatternBenefit benefit = 1);
420+
LogicalResult matchAndRewrite(Operation *op,
421+
PatternRewriter &rewriter) const override;
422+
423+
private:
424+
/// Dependence graph needed for fusion.
425+
const LinalgDependenceGraph &dependenceGraph;
426+
/// Options to control tiling.
427+
LinalgTilingOptions tilingOptions;
428+
/// Options to control fusion.
429+
LinalgFusionOptions fusionOptions;
430+
/// Marker to control application of the pattern.
431+
LinalgMarker marker;
432+
/// Marker set on the fused op after tile and fuse.
433+
LinalgMarker fusedOpMarker;
434+
/// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
435+
/// to build the dependence graph changes then the dependenceGraph needs to be
436+
/// recomputed right now. To not invalidate the dependenceGraph as
437+
/// transformation happens, the original producer can be tagged with a marker
438+
/// that can be later used to delete the original operations.
439+
LinalgMarker originalOpMarker;
440+
};
441+
442+
template <typename OpTy>
443+
struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
444+
LinalgTileAndFusePattern(MLIRContext *context,
445+
const LinalgDependenceGraph &dependenceGraph,
446+
LinalgTilingOptions tilingOptions,
447+
LinalgFusionOptions fusionOptions,
448+
LinalgMarker marker = LinalgMarker(),
449+
LinalgMarker fusedOpMarker = LinalgMarker(),
450+
LinalgMarker originalOpMarker = LinalgMarker(),
451+
PatternBenefit benefit = 1)
452+
: LinalgBaseTileAndFusePattern(
453+
OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
454+
fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
455+
};
456+
326457
///
327458
/// Linalg interchange patterns.
328459
///

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_LINALG_UTILS_H_
1111

1212
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
13+
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
1314
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
1415
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
1516
#include "mlir/Dialect/SCF/SCF.h"

0 commit comments

Comments
 (0)