Skip to content

Commit c13f806

Browse files
authored
[mlir][linalg] raise generic to named ops. (#110421)
Add support for specializing linalg.broadcast and linalg.transpose from generic. Also does some refactoring to reuse specialization checks, migrating some common uses to op interface methods.
1 parent bae17a2 commit c13f806

File tree

9 files changed

+260
-54
lines changed

9 files changed

+260
-54
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp,
120120
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
121121
bool isaCopyOpInterface(LinalgOp linalgOp);
122122

123+
/// Checks whether `genericOp` is semantically equivalent to a
124+
/// `linalg.broadcast`. Returns broadcast dimensions if true.
125+
std::optional<SmallVector<int64_t>>
126+
isaBroadcastOpInterface(GenericOp genericOp);
127+
128+
/// Checks whether `genericOp` is semantically equivalent to a
129+
/// `linalg.transpose`. Returns permuted dimensions if true.
130+
std::optional<SmallVector<int64_t>>
131+
isaTransposeOpInterface(GenericOp genericOp);
132+
123133
/// Checks whether a given `genericOp` is semantically equivalent to a single
124134
/// linalg elementwise unary op. e.g. linalg.exp.
125135
/// A linalg.generic body could be a series of unary elementwise ops e.g.

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,18 @@ def LinalgStructuredInterface
243243
utils::IteratorType::parallel);
244244
}]
245245
>,
246+
InterfaceMethod<
247+
/*desc=*/[{
248+
Return true if all loops are parallel.
249+
}],
250+
/*retTy=*/"bool",
251+
/*methodName=*/"isAllParallelLoops",
252+
/*args=*/(ins),
253+
/*methodBody=*/"",
254+
/*defaultImplementation=*/[{
255+
return getNumParallelLoops() == getNumLoops();
256+
}]
257+
>,
246258
InterfaceMethod<
247259
/*desc=*/[{
248260
Return the dims that are parallel loops.
@@ -327,6 +339,18 @@ def LinalgStructuredInterface
327339
return !getBlock()->getArgument(bbArgNumber).use_empty();
328340
}]
329341
>,
342+
InterfaceMethod<
343+
/*desc=*/[{
344+
Returns true only if linalgOp takes one input and produces one result.
345+
}],
346+
/*retTy=*/"bool",
347+
/*methodName=*/"isSingleInputOutput",
348+
/*args=*/(ins),
349+
/*methodBody=*/"",
350+
/*defaultImplementation=*/[{
351+
return $_op.getNumDpsInputs() == 1 && $_op.getNumDpsInits() == 1;
352+
}]
353+
>,
330354
InterfaceMethod<
331355
/*desc=*/[{
332356
Return true if `opOperand` is an init tensor. This is true when it is

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,24 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
210210
}
211211

212212
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
213+
214+
// Return true only if GenericOp has a single input and single
215+
// output, and the body is a single yieldOp that yields the input.
216+
// This check is useful when trying to determine if the op is
217+
// essentially a transpose, broadcast, copy or something like that.
218+
bool isSingleYieldOp() {
219+
if (!isSingleInputOutput())
220+
return false;
221+
Block *body = getBody();
222+
if (body->getOperations().size() != 1)
223+
return false;
224+
225+
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
226+
if (!yieldOp || yieldOp.getNumOperands() != 1 ||
227+
yieldOp->getOperand(0) != body->getArgument(0))
228+
return false;
229+
return true;
230+
}
213231
}];
214232

215233
let hasCanonicalizer = 1;

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 111 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/ADT/SmallBitVector.h"
2323
#include "llvm/ADT/SmallVector.h"
2424
#include <algorithm>
25+
#include <numeric>
2526

2627
using namespace mlir;
2728
using namespace mlir::linalg;
@@ -53,112 +54,180 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
5354
// CopyOpInterface implementation
5455
//===----------------------------------------------------------------------===//
5556

56-
bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
57-
// Structural.
58-
if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
57+
bool linalg::isaCopyOpInterface(LinalgOp op) {
58+
// Check all loops are parallel and linalgOp is single input and output.
59+
if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
5960
return false;
6061

61-
// Operands and maps.
62-
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
63-
return false;
64-
auto mapRange = linalgOp.getIndexingMapsArray();
62+
auto mapRange = op.getIndexingMapsArray();
6563
if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
6664
!mapRange.back().isIdentity()) {
6765
return false;
6866
}
6967
// Region.
70-
return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
68+
return llvm::hasSingleElement(op.getBlock()->getOperations());
7169
}
7270

7371
//===----------------------------------------------------------------------===//
7472
// FillOpInterface implementation
7573
//===----------------------------------------------------------------------===//
76-
std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
74+
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
7775
// Structural.
78-
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
79-
genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
76+
if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
77+
!op.isSingleYieldOp())
8078
return std::nullopt;
8179

8280
// Input should be referenced and init should not.
83-
if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
84-
genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
81+
if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
82+
op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
8583
return std::nullopt;
8684

87-
OpOperand *value = genericOp.getDpsInputOperand(0);
88-
if (!genericOp.isScalar(value))
85+
OpOperand *value = op.getDpsInputOperand(0);
86+
if (!op.isScalar(value))
8987
return std::nullopt;
88+
return value->get();
89+
}
9090

91-
Block *body = genericOp.getBody();
92-
if (body->getOperations().size() != 1)
91+
//===----------------------------------------------------------------------===//
92+
// BroadcastOpInterface implementation
93+
//===----------------------------------------------------------------------===//
94+
std::optional<SmallVector<int64_t>>
95+
linalg::isaBroadcastOpInterface(GenericOp op) {
96+
// Structural.
97+
if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
98+
!op.isSingleYieldOp())
9399
return std::nullopt;
94100

95-
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
96-
if (!yieldOp || yieldOp.getNumOperands() != 1 ||
97-
yieldOp->getOperand(0) != body->getArgument(0))
101+
auto srcTy = op.getDpsInputOperand(0)->get().getType();
102+
auto dstTy = op.getDpsInitOperand(0)->get().getType();
103+
if (!isa<MemRefType, RankedTensorType>(srcTy) ||
104+
!isa<MemRefType, RankedTensorType>(dstTy))
98105
return std::nullopt;
99-
return value->get();
106+
107+
// Check output is identity map. Broadcast could additionally be
108+
// employing permutation of indices and that would be expressible
109+
// in linalg.generic but is not expressible for named broadcast op.
110+
auto dstMap = op.getIndexingMapsArray()[1];
111+
if (!dstMap.isIdentity())
112+
return std::nullopt;
113+
114+
SmallVector<int64_t> position;
115+
auto srcMap = op.getIndexingMapsArray()[0];
116+
117+
if (srcMap.getResults().size() >= dstMap.getResults().size())
118+
return std::nullopt;
119+
120+
// Check input map is monotonically increasing DimIds.
121+
for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
122+
auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
123+
if (!expr)
124+
return std::nullopt;
125+
int64_t pos = expr.getPosition();
126+
if (i > 0 && pos <= position[i - 1])
127+
return std::nullopt;
128+
position.push_back(expr.getPosition());
129+
}
130+
131+
SmallVector<int64_t> broadcastedDims;
132+
auto numDims = srcMap.getNumDims();
133+
// This is quadratic but number of items is generally small.
134+
for (auto dim : llvm::seq<int64_t>(0, numDims)) {
135+
if (!llvm::is_contained(position, dim))
136+
broadcastedDims.push_back(dim);
137+
}
138+
return broadcastedDims;
139+
}
140+
141+
//===----------------------------------------------------------------------===//
142+
// TransposeOpInterface implementation
143+
//===----------------------------------------------------------------------===//
144+
std::optional<SmallVector<int64_t>>
145+
linalg::isaTransposeOpInterface(GenericOp op) {
146+
// To specialize as a transpose op, the genericOp must be
147+
// all parallel loops, single input, single output, and its body
148+
// should be just a yield op, yielding input as output as is (no compute).
149+
if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
150+
!op.isSingleYieldOp())
151+
return std::nullopt;
152+
153+
auto mapRange = op.getIndexingMapsArray();
154+
if (mapRange.size() != 2)
155+
return std::nullopt;
156+
157+
auto mapOfInput = mapRange.front();
158+
auto mapOfResult = mapRange.back();
159+
160+
// linalg.transpose permutes the dimensions of input using this
161+
// rule: dim(result, i) = dim(input, permutation[i])
162+
if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
163+
return std::nullopt;
164+
165+
SmallVector<int64_t> permutation(mapOfInput.getNumDims());
166+
for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
167+
auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
168+
permutation[expr.getPosition()] = i;
169+
}
170+
return permutation;
100171
}
101172

102173
//===----------------------------------------------------------------------===//
103174
// Elementwise Single Unary/Binary-OpInterface implementation
104175
//===----------------------------------------------------------------------===//
105-
static bool
106-
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
107-
unsigned arity) {
176+
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
177+
unsigned arity) {
108178
// Check all loops are parallel.
109-
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
110-
genericOp.getNumLoops() < 1)
179+
if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
111180
return false;
112181

113182
// Check there are arity-inputs, 1-output and all are identity-maps.
114-
if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
115-
!llvm::all_of(genericOp.getIndexingMapsArray(),
183+
if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
184+
!llvm::all_of(op.getIndexingMapsArray(),
116185
[](AffineMap map) { return map.isIdentity(); }))
117186
return false;
118187

119188
// Init should not be referenced for elementwise operations.
120-
if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
189+
if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
121190
return false;
122191

123192
// A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
124193
// as resulting from producer-consumer fusion. Here, we restrict to two ops in
125194
// the body, where the first is the elementwise single op and the second a
126195
// yield.
127-
Block *body = genericOp.getBody();
196+
Block *body = op.getBody();
128197
if (body->getOperations().size() != 2)
129198
return false;
130199

131-
Operation *op = &body->front();
132-
if (op->getNumOperands() != arity || op->getNumResults() != 1)
200+
Operation *oper = &body->front();
201+
if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
133202
return false;
134203

135204
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
136205
if (!yieldOp || yieldOp.getNumOperands() != 1 ||
137-
yieldOp->getOperand(0).getDefiningOp() != op)
206+
yieldOp->getOperand(0).getDefiningOp() != oper)
138207
return false;
139208
return true;
140209
}
141210

142-
bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
211+
bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
143212
// All basic elemwise checks.
144-
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
213+
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
145214
return false;
146215

147216
// Check input is actually used.
148-
if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
217+
if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
149218
return false;
150219
return true;
151220
}
152221

153-
bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
154-
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
222+
bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
223+
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
155224
return false;
156225

157226
// Check both inputs are used (elementwise).
158-
OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
159-
OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
160-
if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
161-
!genericOp.payloadUsesValueFromOperand(inputOpOperand1))
227+
OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
228+
OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
229+
if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
230+
!op.payloadUsesValueFromOperand(inputOpOperand1))
162231
return false;
163232
return true;
164233
}

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,18 +259,43 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
259259
//===----------------------------------------------------------------------===//
260260
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
261261
GenericOp genericOp) {
262+
// Copy
262263
if (isaCopyOpInterface(genericOp)) {
263264
LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
264265
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
265266
return namedOp;
266267
}
267268

269+
// Fill
268270
if (isaFillOpInterface(genericOp)) {
269271
LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
270272
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
271273
return namedOp;
272274
}
273275

276+
// Broadcast
277+
std::optional<SmallVector<int64_t>> equivalentToBroadcast =
278+
isaBroadcastOpInterface(genericOp);
279+
if (equivalentToBroadcast) {
280+
auto dims = *equivalentToBroadcast;
281+
LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
282+
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
283+
dims);
284+
return namedOp;
285+
}
286+
287+
// Transpose
288+
std::optional<SmallVector<int64_t>> equivalentToTranspose =
289+
isaTransposeOpInterface(genericOp);
290+
if (equivalentToTranspose) {
291+
auto permutation = *equivalentToTranspose;
292+
LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
293+
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
294+
permutation);
295+
return namedOp;
296+
}
297+
298+
// Elementwise Unary
274299
if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
275300
Operation *op = &genericOp.getBody()->front();
276301
if (isa<math::ExpOp>(op)) {
@@ -279,6 +304,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
279304
}
280305
}
281306

307+
// Elementwise Binary
282308
if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
283309
bool swap = areBinOpsSwapped(genericOp);
284310
Operation *op = &genericOp.getBody()->front();
@@ -300,6 +326,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
300326
}
301327
}
302328

329+
// Contraction - e.g. matmul
303330
if (isaContractionOpInterface(genericOp)) {
304331
return specializeLinalgContractions(rewriter, genericOp);
305332
}

0 commit comments

Comments
 (0)