
Commit 3efac5c

[MLIR][Linalg] Add pass to convert linalg.generic back to named ops (llvm#95656)
Add a new mlir-opt pass `--linalg-specialize-generic-ops` which lifts linalg.generic ops, where possible, back to linalg named ops, much like `-linalg-generalize-named-ops` lowers named ops to linalg.generic. Also add patterns to recognize contractions that can be specialized from linalg.generic to the named ops `linalg.{batch_}?matmul{_transpose_(a|b)}?`.
1 parent 69d3793 commit 3efac5c
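
For illustration, a minimal sketch of the pass's effect (the function and value names, shapes, and exact output below are hypothetical, not taken from the commit; the tests further down show the verified round-trips): an all-parallel, identity-map linalg.generic whose body is a single arith.addf is lifted back to linalg.add.

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @example(%A: tensor<4x8xf32>, %B: tensor<4x8xf32>,
                   %Out: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // Generalized form: all loops parallel, identity maps, single addf body.
  %0 = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel", "parallel"]}
      ins(%A, %B : tensor<4x8xf32>, tensor<4x8xf32>)
      outs(%Out : tensor<4x8xf32>) {
  ^bb0(%a: f32, %b: f32, %o: f32):
    %add = arith.addf %a, %b : f32
    linalg.yield %add : f32
  } -> tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}
// After --linalg-specialize-generic-ops (sketch):
//   %0 = linalg.add ins(%A, %B : tensor<4x8xf32>, tensor<4x8xf32>)
//                   outs(%Out : tensor<4x8xf32>) -> tensor<4x8xf32>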

9 files changed: +585, -2 lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 5 additions & 0 deletions
@@ -94,6 +94,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
+  let summary = "Convert generic ops back to named ops";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 23 additions & 0 deletions
@@ -1416,6 +1416,20 @@ struct LinalgGeneralizationPattern
   }
 };
 
+struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  FailureOr<GenericOp>
+  returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
+    return specializeGenericOp(rewriter, op);
+  }
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+};
+
 /// Vectorization pattern for memref::CopyOp.
 struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
   using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -1567,6 +1581,15 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 /// linalg.generic ops.
 void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns to convert linalg.generic ops to named
+/// ops where possible. A linalg.generic can represent a wide range of complex
+/// computations for which an equivalent linalg named op may not exist, e.g.
+/// a linalg.generic that takes a tensor and computes a polynomial such as
+///   p(x) = an*x^n + ... + a1*x + a0
+/// has no equivalent named op to convert to. Many such cases exist.
+void populateLinalgGenericOpsSpecializationPatterns(
+    RewritePatternSet &patterns);
+
 /// Linalg decompose convolutions patterns
 
 /// Populates patterns to decompose high-D convolution ops into low-D ones.
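
The populateLinalgGenericOpsSpecializationPatterns comment above notes that many generics have no named-op counterpart. As an illustration only (hypothetical IR, not part of this commit), a generic that evaluates a polynomial elementwise stays a linalg.generic, because no linalg named op expresses that computation:

#id = affine_map<(d0) -> (d0)>
func.func @poly(%x: tensor<8xf32>, %out: tensor<8xf32>) -> tensor<8xf32> {
  %a0 = arith.constant 1.0 : f32
  %a1 = arith.constant 2.0 : f32
  %a2 = arith.constant 3.0 : f32
  // p(x) = a2*x^2 + a1*x + a0, evaluated with Horner's rule; there is no
  // linalg named op for this, so specialization leaves the generic as-is.
  %0 = linalg.generic {indexing_maps = [#id, #id],
                       iterator_types = ["parallel"]}
      ins(%x : tensor<8xf32>) outs(%out : tensor<8xf32>) {
  ^bb0(%xi: f32, %o: f32):
    %t0 = arith.mulf %a2, %xi : f32
    %t1 = arith.addf %t0, %a1 : f32
    %t2 = arith.mulf %t1, %xi : f32
    %p = arith.addf %t2, %a0 : f32
    linalg.yield %p : f32
  } -> tensor<8xf32>
  return %0 : tensor<8xf32>
}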

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 2 additions & 2 deletions
@@ -105,9 +105,9 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
 static bool
 isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
                                           unsigned arity) {
-  // Check all loops are parallel, and have only tensor semantics.
+  // Check all loops are parallel.
   if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+      genericOp.getNumLoops() < 1)
     return false;
 
   // Check there are arity-inputs, 1-output and all are identity-maps.

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 229 additions & 0 deletions
@@ -11,12 +11,22 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
+namespace mlir {
+#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
 #define DEBUG_TYPE "linalg-specialization"
 
 #define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
@@ -58,6 +68,197 @@ static bool areBinOpsSwapped(GenericOp genericOp)
   return swapped;
 }
 
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+/// Identifies a linalg.generic that is essentially a named op of the form:
+// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+//
+// It is possible that a linalg.generic may be implementing a matmul but not
+// in a straightforward way, e.g. below is a matrix multiply over some slice
+// ```
+// %0 = linalg.generic {
+//        indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
+//                         affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
+//                         affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
+//        iterator_types = ["parallel", "parallel", "parallel"]}
+//        ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
+//        outs(%C : tensor<20x20x20xf32>) {
+//           ^bb0(%a: f32, %b: f32, %c : f32):
+//              %mul = arith.mulf %a, %b : f32
+//              %add = arith.addf %mul, %c : f32
+//              linalg.yield %add : f32
+//       } -> tensor<20x20x20xf32>
+// ```
+// It is not possible to represent the above as a named op.
+// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
+// not the same as the linalg.generic above.
+namespace {
+enum class IndexMatchResult {
+  Match = 0,  // identity map.
+  Transposed, // transposed map.
+  Mismatch    // none of the above.
+};
+
+// Checks whether the input Affine `map` contains two consecutive dims that
+// can be interpreted as accessing a 2D matrix. It is assumed that the row
+// and column dimensions are adjacent axes (in this order) and start at
+// `rowDimIdx` in the input map.
+//
+// e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
+// whether the map of A is identity (match), transposed, or something
+// completely different (mismatch). Similarly for B and C.
+static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
+                                        unsigned expectedPosOfRowDim,
+                                        unsigned expectedPosOfColDim) {
+  // Get the matrix multiply indices. They are past the batch indices.
+  auto exprOfRowDim = map.getResults()[rowDimIdx];
+  auto exprOfColDim = map.getResults()[rowDimIdx + 1];
+
+  // They should be pure dimension ids.
+  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
+      exprOfColDim.getKind() != AffineExprKind::DimId)
+    return IndexMatchResult::Mismatch;
+
+  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
+  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
+
+  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
+    return IndexMatchResult::Match;
+
+  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
+    return IndexMatchResult::Transposed;
+
+  return IndexMatchResult::Mismatch;
+}
+
+// Replaces genericOp with the `NamedOpTy` op, supplied as a template arg.
+// All the variants expressed as a pseudo regular expression:
+// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+// have the same number of ins/outs, so it's easy to stamp different versions.
+template <typename NamedOpTy>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
+      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
+      ValueRange{op.getDpsInits()[0]});
+  return namedOp;
+}
+
+// Converts linalg.generic to named linalg.*matmul* where possible.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
+  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+    return failure();
+
+  // Early exit if not projected permutations.
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(mapRange,
+                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
+    return failure();
+
+  // A linalg generic contraction can be across multiple axes, e.g.
+  // ```
+  // linalg.generic
+  //      {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
+  //                        affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
+  //                        affine_map<(m, n, k1, k2) -> (m, n)>],
+  //       iterator_types = ["parallel", "parallel",
+  //                         "reduction", "reduction"]}
+  //   ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
+  //   outs(%C : tensor<10x40xf32>) {
+  //   ^bb0(%a: f32, %b: f32, %c: f32):
+  //     %1 = arith.mulf %a, %b : f32
+  //     %2 = arith.addf %c, %1 : f32
+  //     linalg.yield %2 : f32
+  //   } -> tensor<10x40xf32>
+  // ```
+  // In the above contraction there are two reduction dimensions {k1, k2}
+  // and, although it is a valid linalg contraction, it is not a named-op
+  // matrix multiply kind. Therefore, reject multi-dim reduction.
+  auto res = inferContractionDims(genericOp);
+  if (!succeeded(res))
+    return failure();
+  auto dims = *res;
+  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+    return failure();
+
+  if (!mlir::linalg::detail::isContractionBody(
+          *genericOp.getBlock(), [](Operation *first, Operation *second) {
+            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
+              return true;
+            return false;
+          }))
+    return failure();
+
+  // Check the rank of the operands.
+  auto indexingMaps = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
+        return m.getResults().size() !=
+               dims.batch.size() + 2 /* any two of {m,n,k} */;
+      }))
+    return failure();
+
+  auto numOfBatchDims = dims.batch.size();
+  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
+    return failure();
+
+  if (numOfBatchDims) {
+    // Each operand in a linalg generic contraction could express different
+    // permutations for its batch dimension. But for a named op it must be
+    // the identity since separate maps are not specified.
+    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
+          for (unsigned i = 0; i < numOfBatchDims; ++i) {
+            auto expr = m.getResults()[i];
+            if (expr.getKind() != AffineExprKind::DimId ||
+                cast<AffineDimExpr>(expr).getPosition() != i)
+              return true;
+          }
+          return false;
+        }))
+      return failure();
+  }
+
+  auto a =
+      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
+  auto b =
+      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
+  auto c =
+      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
+
+  if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
+        return r == IndexMatchResult::Mismatch;
+      }))
+    return failure();
+
+  if (c != IndexMatchResult::Match ||
+      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
+    return failure();
+
+  /// Codegen the different matmul variants.
+  if (numOfBatchDims) {
+    if (a == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+                                                               genericOp);
+    if (b == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+                                                               genericOp);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+  }
+
+  if (a == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
+  if (b == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Categorize linalg generic to named op where possible.
+//===----------------------------------------------------------------------===//
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
   if (isaCopyOpInterface(genericOp)) {
@@ -100,5 +301,33 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
       return namedOp;
     }
   }
+
+  if (isaContractionOpInterface(genericOp)) {
+    return specializeLinalgContractions(rewriter, genericOp);
+  }
   return failure();
 }
+
+namespace {
+struct LinalgSpecializeGenericOpsPass
+    : public impl::LinalgSpecializeGenericOpsPassBase<
+          LinalgSpecializeGenericOpsPass> {
+
+  using impl::LinalgSpecializeGenericOpsPassBase<
+      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
+  void runOnOperation() override;
+};
+} // namespace
+
+void LinalgSpecializeGenericOpsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  populateLinalgGenericOpsSpecializationPatterns(patterns);
+
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    signalPassFailure();
+}
+
+void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
+}
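
To illustrate the contraction matching in specializeLinalgContractions above, here is a hedged sketch (hypothetical input, not one of the commit's test cases): a single-reduction contraction whose A map swaps the m and k dimensions is classified as Transposed for A, with B and C matching, so it would be rewritten to linalg.matmul_transpose_a.

#mapA = affine_map<(m, n, k) -> (k, m)>
#mapB = affine_map<(m, n, k) -> (k, n)>
#mapC = affine_map<(m, n, k) -> (m, n)>
func.func @lift_to_matmul_transpose_a(%A: tensor<5x3xf32>, %B: tensor<5x7xf32>,
                                      %C: tensor<3x7xf32>) -> tensor<3x7xf32> {
  // A is accessed as (k, m): Transposed. B is (k, n) and C is (m, n): Match.
  %0 = linalg.generic {indexing_maps = [#mapA, #mapB, #mapC],
                       iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%A, %B : tensor<5x3xf32>, tensor<5x7xf32>)
      outs(%C : tensor<3x7xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %mul = arith.mulf %a, %b : f32
    %add = arith.addf %c, %mul : f32
    linalg.yield %add : f32
  } -> tensor<3x7xf32>
  return %0 : tensor<3x7xf32>
}
// Expected rewrite (sketch): the pass replaces the generic with
//   linalg.matmul_transpose_a ins(%A, %B : tensor<5x3xf32>, tensor<5x7xf32>)
//                             outs(%C : tensor<3x7xf32>) -> tensor<3x7xf32>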
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+// The following are examples of linalg named ops lowered to linalg.generic and then
+// lifted back up to the named op.
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @unary_exp(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
+  linalg.exp ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK-LABEL: unary_exp
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+
+// -----
+
+func.func @binary_add(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: binary_add
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @mixed_named_ops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                           %C: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.add ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @mixed_named_ops
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
