Skip to content

Commit f36a1f6

Browse files
committed
[mlir][linalg] Enable fuse consumer
This patch adds support for consumer fusion to the tiling interface, and implements consumer fusion in FuseIntoContainingOp. - Add interface method 'getIterationDomainTileFromOperandTile' to the tiling interface, which computes the iteration domain tile from an operand tile. - Add interface method 'getTiledImplementationFromOperandTile' to the tiling interface, which generates the tiled implementation from an operand tile. - Implement the above two methods and support consumer fusion for FuseIntoContainingOp.
1 parent 4616368 commit f36a1f6

File tree

8 files changed

+570
-69
lines changed

8 files changed

+570
-69
lines changed

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 90 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
2828
/*desc=*/[{
2929
Returns a list of iterator types that describe the number of loops.
3030
}],
31-
/*retType=*/"SmallVector<utils::IteratorType>",
31+
/*retType=*/"SmallVector<::mlir::utils::IteratorType>",
3232
/*methodName=*/"getLoopIteratorTypes",
3333
/*args=*/(ins),
3434
/*methodBody=*/"",
@@ -39,9 +39,9 @@ def TilingInterface : OpInterface<"TilingInterface"> {
3939
Returns a list of ranges that describe the loop bounds and
4040
step for the loops of the operation.
4141
}],
42-
/*retTy=*/"SmallVector<Range>",
42+
/*retTy=*/"SmallVector<::mlir::Range>",
4343
/*methodName=*/"getIterationDomain",
44-
/*args=*/(ins "OpBuilder &":$b),
44+
/*args=*/(ins "::mlir::OpBuilder &":$b),
4545
/*methodBody=*/"",
4646
/*defaultImplementation=*/"return {};"
4747
>,
@@ -63,12 +63,12 @@ def TilingInterface : OpInterface<"TilingInterface"> {
6363
The method returns the operation that is the tiled
6464
implementation.
6565
}],
66-
/*retType=*/"FailureOr<TilingResult>",
66+
/*retType=*/"FailureOr<::mlir::TilingResult>",
6767
/*methodName=*/"getTiledImplementation",
6868
/*args=*/(ins
69-
"OpBuilder &":$b,
70-
"ArrayRef<OpFoldResult> ":$offsets,
71-
"ArrayRef<OpFoldResult> ":$sizes),
69+
"::mlir::OpBuilder &":$b,
70+
"ArrayRef<::mlir::OpFoldResult> ":$offsets,
71+
"ArrayRef<::mlir::OpFoldResult> ":$sizes),
7272
/*methodBody=*/"",
7373
/*defaultImplementation=*/[{
7474
return {};
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
8282
by the tiled implementation. Expects the same `offsets` and `sizes` as
8383
used to obtain the tiled implementation of the operation.
8484
}],
85-
/*retType=*/"LogicalResult",
85+
/*retType=*/"::mlir::LogicalResult",
8686
/*methodName=*/"getResultTilePosition",
8787
/*args=*/(ins
88-
"OpBuilder &":$b,
88+
"::mlir::OpBuilder &":$b,
8989
"unsigned":$resultNumber,
90-
"ArrayRef<OpFoldResult> ":$offsets,
91-
"ArrayRef<OpFoldResult> ":$sizes,
92-
"SmallVector<OpFoldResult> &":$resultOffsets,
93-
"SmallVector<OpFoldResult> &":$resultSizes),
90+
"ArrayRef<::mlir::OpFoldResult> ":$offsets,
91+
"ArrayRef<::mlir::OpFoldResult> ":$sizes,
92+
"SmallVectorImpl<::mlir::OpFoldResult> &":$resultOffsets,
93+
"SmallVectorImpl<::mlir::OpFoldResult> &":$resultSizes),
94+
/*methodBody=*/"",
95+
/*defaultImplementation=*/[{
96+
return failure();
97+
}]
98+
>,
99+
InterfaceMethod<
100+
/*desc=*/[{
101+
Method to return the tile of the iteration domain that corresponds to
102+
the given tile (`offsets` and `sizes`) of the operand `operandNumber`.
103+
}],
104+
/*retType=*/"::mlir::LogicalResult",
105+
/*methodName=*/"getIterationDomainTileFromOperandTile",
106+
/*args=*/(ins
107+
"::mlir::OpBuilder &":$b,
108+
"unsigned":$operandNumber,
109+
"ArrayRef<::mlir::OpFoldResult> ":$offsets,
110+
"ArrayRef<::mlir::OpFoldResult> ":$sizes,
111+
"SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
112+
"SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
94113
/*methodBody=*/"",
95114
/*defaultImplementation=*/[{
96115
return failure();
@@ -119,13 +138,49 @@ def TilingInterface : OpInterface<"TilingInterface"> {
119138
iteration space).
120139
- `sizes` provides the size of the tile.
121140
}],
122-
/*retType=*/"FailureOr<TilingResult>",
141+
/*retType=*/"FailureOr<::mlir::TilingResult>",
123142
/*methodName=*/"generateResultTileValue",
124143
/*args=*/(ins
125-
"OpBuilder &":$b,
144+
"::mlir::OpBuilder &":$b,
126145
"unsigned":$resultNumber,
127-
"ArrayRef<OpFoldResult>":$offsets,
128-
"ArrayRef<OpFoldResult>":$sizes),
146+
"ArrayRef<::mlir::OpFoldResult>":$offsets,
147+
"ArrayRef<::mlir::OpFoldResult>":$sizes),
148+
/*methodBody=*/"",
149+
/*defaultImplementation=*/[{
150+
return failure();
151+
}]
152+
>,
153+
InterfaceMethod<
154+
/*desc=*/[{
155+
Method to generate the tiled implementation of an operation from
156+
operand tile position.
157+
158+
Generates the IR that computes the tiled implementation of an
159+
operation from operand tile. The `offsets` and `sizes`
160+
describe the tile of the operand required. This is different from
161+
`getTiledImplementation` which generates the tiled
162+
implementation of the operation given a tile of the
163+
iteration space. This method generates a tiled
164+
implementation of the operation based on the tile of the
165+
operand required. This method enables consumer fusion by using
166+
tile and fuse. The method returns failure if the operation
167+
can't be tiled to generate the operand tile. In practical terms
168+
this implies it cannot be tiled and fused with its producers.
169+
170+
- `offsets` provides the offset of the tile in the coordinate system
171+
of the original iteration space, i.e., if an iteration space
172+
dimension had non-zero offset, it must be included in the offset
173+
provided here (as opposed to zero-based offset "relative" to the
174+
iteration space).
175+
- `sizes` provides the size of the tile.
176+
}],
177+
/*retType=*/"FailureOr<::mlir::TilingResult>",
178+
/*methodName=*/"getTiledImplementationFromOperandTile",
179+
/*args=*/(ins
180+
"::mlir::OpBuilder &":$b,
181+
"unsigned":$operandNumber,
182+
"ArrayRef<::mlir::OpFoldResult>":$offsets,
183+
"ArrayRef<::mlir::OpFoldResult>":$sizes),
129184
/*methodBody=*/"",
130185
/*defaultImplementation=*/[{
131186
return failure();
@@ -142,12 +197,12 @@ def TilingInterface : OpInterface<"TilingInterface"> {
142197
transformations are done, this method can be used to lower to scalar
143198
code that can then be lowered to LLVM or SPIR-V dialects.
144199
}],
145-
/*retType=*/"LogicalResult",
200+
/*retType=*/"::mlir::LogicalResult",
146201
/*methodName=*/"generateScalarImplementation",
147202
/*args=*/(ins
148-
"OpBuilder &":$b,
149-
"Location ":$loc,
150-
"ValueRange ":$ivs),
203+
"::mlir::OpBuilder &":$b,
204+
"::mlir::Location ":$loc,
205+
"::mlir::ValueRange ":$ivs),
151206
/*methodBody=*/"",
152207
/*defaultImplementation=*/[{
153208
return failure();
@@ -170,12 +225,12 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
170225
operation reduction. The tensor shape is equal to operation result
171226
shape with new dimension for each non zero tile size.
172227
}],
173-
/*retType=*/"FailureOr<Operation*>",
228+
/*retType=*/"FailureOr<::mlir::Operation*>",
174229
/*methodName=*/"generateInitialTensorForPartialReduction",
175230
/*args=*/(ins
176-
"OpBuilder &":$b,
177-
"Location ":$loc,
178-
"ArrayRef<OpFoldResult>":$sizes,
231+
"::mlir::OpBuilder &":$b,
232+
"::mlir::Location ":$loc,
233+
"ArrayRef<::mlir::OpFoldResult>":$sizes,
179234
"ArrayRef<int>":$reductionDim),
180235
/*methodBody=*/"",
181236
/*defaultImplementation=*/[{
@@ -189,14 +244,14 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
189244
less or equal to the tile size. This is meant to be used with
190245
`mergeReductions` method which will combine the partial reductions.
191246
}],
192-
/*retType=*/"Operation*",
247+
/*retType=*/"::mlir::Operation*",
193248
/*methodName=*/"tileToPartialReduction",
194249
/*args=*/(ins
195-
"OpBuilder &":$b,
196-
"Location ":$loc,
197-
"ValueRange":$init,
198-
"ArrayRef<OpFoldResult>":$offsets,
199-
"ArrayRef<OpFoldResult>":$sizes,
250+
"::mlir::OpBuilder &":$b,
251+
"::mlir::Location ":$loc,
252+
"::mlir::ValueRange":$init,
253+
"ArrayRef<::mlir::OpFoldResult>":$offsets,
254+
"ArrayRef<::mlir::OpFoldResult>":$sizes,
200255
"ArrayRef<int>":$reductionDims),
201256
/*methodBody=*/"",
202257
/*defaultImplementation=*/[{
@@ -209,12 +264,12 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
209264
tiled along the reduction dimensions. This will only apply the
210265
reduction the operation.
211266
}],
212-
/*retType=*/"Operation*",
267+
/*retType=*/"::mlir::Operation*",
213268
/*methodName=*/"mergeReductions",
214269
/*args=*/(ins
215-
"OpBuilder &":$b,
216-
"Location ":$loc,
217-
"ValueRange":$partialReduce,
270+
"::mlir::OpBuilder &":$b,
271+
"::mlir::Location ":$loc,
272+
"::mlir::ValueRange":$partialReduce,
218273
"ArrayRef<int>":$reductionDim),
219274
/*methodBody=*/"",
220275
/*defaultImplementation=*/[{

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
24252425

24262426
LogicalResult SoftmaxOp::getResultTilePosition(
24272427
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2428-
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2429-
SmallVector<OpFoldResult> &resultSizes) {
2428+
ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
2429+
SmallVectorImpl<OpFoldResult> &resultSizes) {
24302430
if (resultNumber == 0) {
24312431
resultOffsets.assign(offsets.begin(), offsets.end());
24322432
resultSizes.assign(sizes.begin(), sizes.end());

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 79 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
110110
}));
111111
}
112112

113-
// Instantiate the tiled implementation of the operation.
113+
/// Instantiate the tiled implementation of the operation.
114114
FailureOr<TilingResult>
115115
getTiledImplementation(Operation *op, OpBuilder &b,
116116
ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,65 @@ struct LinalgOpTilingInterface
132132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
133133
}
134134

135-
// Return the details of the output tile generated by the tiled
136-
// implementation.
135+
void
136+
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
137+
ArrayRef<OpFoldResult> offsets,
138+
ArrayRef<OpFoldResult> sizes,
139+
SmallVectorImpl<OpFoldResult> &mappedOffsets,
140+
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
141+
unsigned numLoops = linalgOp.getNumLoops();
142+
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
143+
mappedOffsets.resize(numLoops);
144+
mappedSizes.resize(numLoops);
145+
if (!indexingMap.isPermutation()) {
146+
SmallVector<Range> iterationDomain =
147+
tilingInterfaceOp.getIterationDomain(b);
148+
for (auto &&[index, value] : llvm::enumerate(iterationDomain)) {
149+
mappedOffsets[index] = value.offset;
150+
mappedSizes[index] = value.size;
151+
}
152+
}
153+
for (auto &&[index, value] : llvm::enumerate(indexingMap.getResults())) {
154+
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
155+
mappedOffsets[dimPosition] = offsets[index];
156+
mappedSizes[dimPosition] = sizes[index];
157+
}
158+
}
159+
160+
/// Return the details of the output tile generated by the tiled
161+
/// implementation.
162+
LogicalResult getIterationDomainTileFromOperandTile(
163+
Operation *op, OpBuilder &b, unsigned operandNumber,
164+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
165+
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
166+
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
167+
auto linalgOp = cast<LinalgOp>(op);
168+
169+
// Check that the indexing map used for the operand is a projected
170+
// permutation. This could be relaxed with a more general approach that can
171+
// map the offsets and sizes from the operand to iteration space tiles
172+
// (filling in full extent for dimensions not used to access the operand).
173+
AffineMap indexingMap =
174+
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
175+
if (!indexingMap.isProjectedPermutation()) {
176+
return emitError(op->getLoc(),
177+
"unhandled get iter domain position when operand is not "
178+
"accessed using a permuted projection");
179+
}
180+
181+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
182+
iterDomainOffsets, iterDomainSizes);
183+
return success();
184+
}
185+
186+
/// Return the details of the output tile generated by the tiled
187+
/// implementation.
137188
LogicalResult
138189
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
139190
ArrayRef<OpFoldResult> offsets,
140191
ArrayRef<OpFoldResult> sizes,
141-
SmallVector<OpFoldResult> &resultOffsets,
142-
SmallVector<OpFoldResult> &resultSizes) const {
192+
SmallVectorImpl<OpFoldResult> &resultOffsets,
193+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
143194
Location loc = op->getLoc();
144195
LinalgOp linalgOp = cast<LinalgOp>(op);
145196

@@ -160,6 +211,21 @@ struct LinalgOpTilingInterface
160211
return success();
161212
}
162213

214+
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
215+
Operation *op, OpBuilder &b, unsigned operandNumber,
216+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
217+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
218+
auto tilingInterfaceOp = cast<TilingInterface>(op);
219+
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
220+
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
221+
return emitError(
222+
op->getLoc(),
223+
"unable to obtain the iter domain position of the operation.");
224+
}
225+
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
226+
mappedSizes);
227+
}
228+
163229
FailureOr<TilingResult>
164230
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
165231
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +243,16 @@ struct LinalgOpTilingInterface
177243
"unhandled tiled implementation generation when result is not "
178244
"accessed using a permuted projection");
179245
}
180-
181-
auto numLoops = linalgOp.getNumLoops();
246+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
247+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
248+
mappedOffsets, mappedSizes);
182249
auto tilingInterfaceOp = cast<TilingInterface>(op);
183-
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
184-
iterationTileSizes(numLoops);
185-
if (!indexingMap.isPermutation()) {
186-
SmallVector<Range> iterationDomain =
187-
tilingInterfaceOp.getIterationDomain(b);
188-
for (const auto &range : llvm::enumerate(iterationDomain)) {
189-
iterationTileOffsets[range.index()] = range.value().offset;
190-
iterationTileSizes[range.index()] = range.value().size;
191-
}
192-
}
193-
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
194-
unsigned dimPosition =
195-
cast<AffineDimExpr>(resultExpr.value()).getPosition();
196-
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
197-
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
198-
}
199-
200250
FailureOr<TilingResult> tilingResult =
201-
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
202-
iterationTileSizes);
251+
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
252+
253+
if (failed(tilingResult))
254+
return failure();
255+
203256
if (tilingResult->tiledOps.size() != 1)
204257
return op->emitOpError("failed to generate tiled implementation");
205258

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
6161
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
6262
ArrayRef<OpFoldResult> offsets,
6363
ArrayRef<OpFoldResult> sizes,
64-
SmallVector<OpFoldResult> &resultOffsets,
65-
SmallVector<OpFoldResult> &resultSizes) const {
64+
SmallVectorImpl<OpFoldResult> &resultOffsets,
65+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
6666
resultOffsets.assign(offsets.begin(), offsets.end());
6767
resultSizes.assign(sizes.begin(), sizes.end());
6868
return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
199199
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
200200
ArrayRef<OpFoldResult> offsets,
201201
ArrayRef<OpFoldResult> sizes,
202-
SmallVector<OpFoldResult> &resultOffsets,
203-
SmallVector<OpFoldResult> &resultSizes) const {
202+
SmallVectorImpl<OpFoldResult> &resultOffsets,
203+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
204204
// The iteration domain is over outer dimensions of packed layout. In this
205205
// context, the outer dimensions of `resultOffsets` are `offsets`. The
206206
// inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
452452
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
453453
ArrayRef<OpFoldResult> offsets,
454454
ArrayRef<OpFoldResult> sizes,
455-
SmallVector<OpFoldResult> &resultOffsets,
456-
SmallVector<OpFoldResult> &resultSizes) const {
455+
SmallVectorImpl<OpFoldResult> &resultOffsets,
456+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
457457
resultOffsets = llvm::to_vector(offsets);
458458
resultSizes = llvm::to_vector(sizes);
459459
return success();

0 commit comments

Comments
 (0)