Skip to content

Commit b7a0a85

Browse files
cxy-1993chenxunyu
authored andcommitted
[mlir][linalg] Enable fuse consumer
This patch adds support for consumer fusion to the tiling interface. - Add interface method 'getIterationDomainTileFromOperandTile' to tiling interface which get iteration domain position from operand position. - Add interface method 'getTiledImplementationFromOperandTile' to tiling interface which generate tiled implementation according to operand position.
1 parent 52a1998 commit b7a0a85

File tree

2 files changed

+129
-21
lines changed

2 files changed

+129
-21
lines changed

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
9696
return failure();
9797
}]
9898
>,
99+
InterfaceMethod<
100+
/*desc=*/[{
101+
Method to return the position of iteration domain tile computed by the
102+
tiled operation.
103+
}],
104+
/*retType=*/"LogicalResult",
105+
/*methodName=*/"getIterationDomainTileFromOperandTile",
106+
/*args=*/(ins
107+
"OpBuilder &":$b,
108+
"unsigned":$operandNumber,
109+
"ArrayRef<OpFoldResult> ":$offsets,
110+
"ArrayRef<OpFoldResult> ":$sizes,
111+
"SmallVector<OpFoldResult> &":$iterDomainOffsets,
112+
"SmallVector<OpFoldResult> &":$iterDomainSizes),
113+
/*methodBody=*/"",
114+
/*defaultImplementation=*/[{
115+
return failure();
116+
}]
117+
>,
99118
InterfaceMethod<
100119
/*desc=*/[{
101120
Method to generate the code that produces a tile of the result.
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
131150
return failure();
132151
}]
133152
>,
153+
InterfaceMethod<
154+
/*desc=*/[{
155+
Method to generate the tiled implementation of an operation from
156+
operand tile position.
157+
158+
Generates the IR that computes the tiled implementation of an
159+
operation from operand tile. The `offsets` and `sizes`
160+
describe the tile of the operand required. This is different from
161+
`getTiledImplementation` which generates the tiled
162+
implementation of the operation given a tile of the
163+
iteration space. This method generates a tiled
164+
implementation of the operation based on the tile of the
165+
operand required. This method enables fusion consumer by using
166+
tile and fuse. The method returns failure if the operation
167+
can't be tiled to generate the operand tile. In practical terms
168+
this implies it cannot be tiled and fused with its producers.
169+
170+
- `offsets` provides the offset of the tile in the coordinate system
171+
of the original iteration space, i.e., if an iteration space
172+
dimension had non-zero offset, it must be included in the offset
173+
provided here (as opposed to zero-based offset "relative" to the
174+
iteration space).
175+
- `sizes` provides the size of the tile.
176+
}],
177+
/*retType=*/"FailureOr<TilingResult>",
178+
/*methodName=*/"getTiledImplementationFromOperandTile",
179+
/*args=*/(ins
180+
"OpBuilder &":$b,
181+
"unsigned":$operandNumber,
182+
"ArrayRef<OpFoldResult>":$offsets,
183+
"ArrayRef<OpFoldResult>":$sizes),
184+
/*methodBody=*/"",
185+
/*defaultImplementation=*/[{
186+
return failure();
187+
}]
188+
>,
134189
InterfaceMethod<
135190
/*desc=*/[{
136191
Generates the scalar implementation of the operation.

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,58 @@ struct LinalgOpTilingInterface
132132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
133133
}
134134

135+
void getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
136+
AffineMap indexingMap,
137+
ArrayRef<OpFoldResult> offsets,
138+
ArrayRef<OpFoldResult> sizes,
139+
SmallVector<OpFoldResult> &mappedOffsets,
140+
SmallVector<OpFoldResult> &mappedSizes) const {
141+
auto numLoops = linalgOp.getNumLoops();
142+
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
143+
mappedOffsets.resize(numLoops);
144+
mappedSizes.resize(numLoops);
145+
if (!indexingMap.isPermutation()) {
146+
SmallVector<Range> iterationDomain =
147+
tilingInterfaceOp.getIterationDomain(b);
148+
for (const auto &range : llvm::enumerate(iterationDomain)) {
149+
mappedOffsets[range.index()] = range.value().offset;
150+
mappedSizes[range.index()] = range.value().size;
151+
}
152+
}
153+
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
154+
unsigned dimPosition =
155+
cast<AffineDimExpr>(resultExpr.value()).getPosition();
156+
mappedOffsets[dimPosition] = offsets[resultExpr.index()];
157+
mappedSizes[dimPosition] = sizes[resultExpr.index()];
158+
}
159+
}
160+
161+
// Return the details of the output tile generated by the tiled
162+
// implementation.
163+
LogicalResult getIterationDomainTileFromOperandTile(
164+
Operation *op, OpBuilder &b, unsigned operandNumber,
165+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
166+
SmallVector<OpFoldResult> &iterDomainOffsets,
167+
SmallVector<OpFoldResult> &iterDomainSizes) const {
168+
auto linalgOp = cast<LinalgOp>(op);
169+
170+
// Check that the indexing map used for the operand is a projected
171+
// permutation. This could be relaxed with a more general approach that can
172+
// map the offsets and sizes from the operand to iteration space tiles
173+
// (filling in full extent for dimensions not used to access the result).
174+
AffineMap indexingMap =
175+
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
176+
if (!indexingMap.isProjectedPermutation()) {
177+
return op->emitOpError(
178+
"unhandled get iter domain position when operand is not "
179+
"accessed using a permuted projection");
180+
}
181+
182+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
183+
iterDomainOffsets, iterDomainSizes);
184+
return success();
185+
}
186+
135187
// Return the details of the output tile generated by the tiled
136188
// implementation.
137189
LogicalResult
@@ -160,6 +212,20 @@ struct LinalgOpTilingInterface
160212
return success();
161213
}
162214

215+
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
216+
Operation *op, OpBuilder &b, unsigned operandNumber,
217+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
218+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219+
auto tilingInterfaceOp = cast<TilingInterface>(op);
220+
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
221+
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
222+
return op->emitOpError(
223+
"unable to obtain the iter domain position of the operation.");
224+
}
225+
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
226+
mappedSizes);
227+
}
228+
163229
FailureOr<TilingResult>
164230
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
165231
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +243,16 @@ struct LinalgOpTilingInterface
177243
"unhandled tiled implementation generation when result is not "
178244
"accessed using a permuted projection");
179245
}
180-
181-
auto numLoops = linalgOp.getNumLoops();
246+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
247+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
248+
mappedOffsets, mappedSizes);
182249
auto tilingInterfaceOp = cast<TilingInterface>(op);
183-
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
184-
iterationTileSizes(numLoops);
185-
if (!indexingMap.isPermutation()) {
186-
SmallVector<Range> iterationDomain =
187-
tilingInterfaceOp.getIterationDomain(b);
188-
for (const auto &range : llvm::enumerate(iterationDomain)) {
189-
iterationTileOffsets[range.index()] = range.value().offset;
190-
iterationTileSizes[range.index()] = range.value().size;
191-
}
192-
}
193-
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
194-
unsigned dimPosition =
195-
cast<AffineDimExpr>(resultExpr.value()).getPosition();
196-
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
197-
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
198-
}
199-
200250
FailureOr<TilingResult> tilingResult =
201-
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
202-
iterationTileSizes);
251+
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
252+
253+
if (failed(tilingResult))
254+
return failure();
255+
203256
if (tilingResult->tiledOps.size() != 1)
204257
return op->emitOpError("failed to generate tiled implementation");
205258

0 commit comments

Comments
 (0)