Skip to content

Commit 1cc3b31

Browse files
authored
Revert "[MLIR][Linalg] Enable fuse consumer (#85528)"
This reverts commit 2a47ee0.
1 parent da57609 commit 1cc3b31

File tree

4 files changed

+40
-149
lines changed

4 files changed

+40
-149
lines changed

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 6 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
6363
The method returns the operation that is the tiled
6464
implementation.
6565
}],
66-
/*retType=*/"FailureOr<::mlir::TilingResult>",
66+
/*retType=*/"FailureOr<TilingResult>",
6767
/*methodName=*/"getTiledImplementation",
6868
/*args=*/(ins
6969
"OpBuilder &":$b,
@@ -82,34 +82,15 @@ def TilingInterface : OpInterface<"TilingInterface"> {
8282
by the tiled implementation. Expects the same `offsets` and `sizes` as
8383
used to obtain the tiled implementation of the operation.
8484
}],
85-
/*retType=*/"::mlir::LogicalResult",
85+
/*retType=*/"LogicalResult",
8686
/*methodName=*/"getResultTilePosition",
8787
/*args=*/(ins
8888
"OpBuilder &":$b,
8989
"unsigned":$resultNumber,
9090
"ArrayRef<OpFoldResult> ":$offsets,
9191
"ArrayRef<OpFoldResult> ":$sizes,
92-
"SmallVectorImpl<OpFoldResult> &":$resultOffsets,
93-
"SmallVectorImpl<OpFoldResult> &":$resultSizes),
94-
/*methodBody=*/"",
95-
/*defaultImplementation=*/[{
96-
return failure();
97-
}]
98-
>,
99-
InterfaceMethod<
100-
/*desc=*/[{
101-
Method to return the position of iteration domain tile computed by the
102-
tiled operation.
103-
}],
104-
/*retType=*/"::mlir::LogicalResult",
105-
/*methodName=*/"getIterationDomainTileFromOperandTile",
106-
/*args=*/(ins
107-
"OpBuilder &":$b,
108-
"unsigned":$operandNumber,
109-
"ArrayRef<OpFoldResult> ":$offsets,
110-
"ArrayRef<OpFoldResult> ":$sizes,
111-
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
112-
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
92+
"SmallVector<OpFoldResult> &":$resultOffsets,
93+
"SmallVector<OpFoldResult> &":$resultSizes),
11394
/*methodBody=*/"",
11495
/*defaultImplementation=*/[{
11596
return failure();
@@ -138,7 +119,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
138119
iteration space).
139120
- `sizes` provides the size of the tile.
140121
}],
141-
/*retType=*/"FailureOr<::mlir::TilingResult>",
122+
/*retType=*/"FailureOr<TilingResult>",
142123
/*methodName=*/"generateResultTileValue",
143124
/*args=*/(ins
144125
"OpBuilder &":$b,
@@ -150,42 +131,6 @@ def TilingInterface : OpInterface<"TilingInterface"> {
150131
return failure();
151132
}]
152133
>,
153-
InterfaceMethod<
154-
/*desc=*/[{
155-
Method to generate the tiled implementation of an operation from
156-
operand tile position.
157-
158-
Generates the IR that computes the tiled implementation of an
159-
operation from operand tile. The `offsets` and `sizes`
160-
describe the tile of the operand required. This is different from
161-
`getTiledImplementation` which generates the tiled
162-
implementation of the operation given a tile of the
163-
iteration space. This method generates a tiled
164-
implementation of the operation based on the tile of the
165-
operand required. This method enables consumer fusion by using
166-
tile and fuse. The method returns failure if the operation
167-
can't be tiled to generate the operand tile. In practical terms
168-
this implies it cannot be tiled and fused with its producers.
169-
170-
- `offsets` provides the offset of the tile in the coordinate system
171-
of the original iteration space, i.e., if an iteration space
172-
dimension had non-zero offset, it must be included in the offset
173-
provided here (as opposed to zero-based offset "relative" to the
174-
iteration space).
175-
- `sizes` provides the size of the tile.
176-
}],
177-
/*retType=*/"FailureOr<::mlir::TilingResult>",
178-
/*methodName=*/"getTiledImplementationFromOperandTile",
179-
/*args=*/(ins
180-
"OpBuilder &":$b,
181-
"unsigned":$operandNumber,
182-
"ArrayRef<OpFoldResult>":$offsets,
183-
"ArrayRef<OpFoldResult>":$sizes),
184-
/*methodBody=*/"",
185-
/*defaultImplementation=*/[{
186-
return failure();
187-
}]
188-
>,
189134
InterfaceMethod<
190135
/*desc=*/[{
191136
Generates the scalar implementation of the operation.
@@ -197,7 +142,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
197142
transformations are done, this method can be used to lower to scalar
198143
code that can then be lowered to LLVM or SPIR-V dialects.
199144
}],
200-
/*retType=*/"::mlir::LogicalResult",
145+
/*retType=*/"LogicalResult",
201146
/*methodName=*/"generateScalarImplementation",
202147
/*args=*/(ins
203148
"OpBuilder &":$b,

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
24252425

24262426
LogicalResult SoftmaxOp::getResultTilePosition(
24272427
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2428-
ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
2429-
SmallVectorImpl<OpFoldResult> &resultSizes) {
2428+
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2429+
SmallVector<OpFoldResult> &resultSizes) {
24302430
if (resultNumber == 0) {
24312431
resultOffsets.assign(offsets.begin(), offsets.end());
24322432
resultSizes.assign(sizes.begin(), sizes.end());

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 26 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
110110
}));
111111
}
112112

113-
/// Instantiate the tiled implementation of the operation.
113+
// Instantiate the tiled implementation of the operation.
114114
FailureOr<TilingResult>
115115
getTiledImplementation(Operation *op, OpBuilder &b,
116116
ArrayRef<OpFoldResult> offsets,
@@ -132,66 +132,14 @@ struct LinalgOpTilingInterface
132132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
133133
}
134134

135-
void
136-
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
137-
ArrayRef<OpFoldResult> offsets,
138-
ArrayRef<OpFoldResult> sizes,
139-
SmallVectorImpl<OpFoldResult> &mappedOffsets,
140-
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
141-
unsigned numLoops = linalgOp.getNumLoops();
142-
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
143-
mappedOffsets.resize(numLoops);
144-
mappedSizes.resize(numLoops);
145-
if (!indexingMap.isPermutation()) {
146-
SmallVector<Range> iterationDomain =
147-
tilingInterfaceOp.getIterationDomain(b);
148-
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
149-
mappedOffsets[index] = value.offset;
150-
mappedSizes[index] = value.size;
151-
}
152-
}
153-
for (const auto &&[index, value] :
154-
llvm::enumerate(indexingMap.getResults())) {
155-
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
156-
mappedOffsets[dimPosition] = offsets[index];
157-
mappedSizes[dimPosition] = sizes[index];
158-
}
159-
}
160-
161-
/// Return the details of the output tile generated by the tiled
162-
/// implementation.
163-
LogicalResult getIterationDomainTileFromOperandTile(
164-
Operation *op, OpBuilder &b, unsigned operandNumber,
165-
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
166-
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
167-
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
168-
auto linalgOp = cast<LinalgOp>(op);
169-
170-
// Check that the indexing map used for the operand is a projected
171-
// permutation. This could be relaxed with a more general approach that can
172-
// map the offsets and sizes from the operand to iteration space tiles
173-
// (filling in full extent for dimensions not used to access the result).
174-
AffineMap indexingMap =
175-
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
176-
if (!indexingMap.isProjectedPermutation()) {
177-
return emitError(op->getLoc(),
178-
"unhandled get iter domain position when operand is not "
179-
"accessed using a permuted projection");
180-
}
181-
182-
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
183-
iterDomainOffsets, iterDomainSizes);
184-
return success();
185-
}
186-
187-
/// Return the details of the output tile generated by the tiled
188-
/// implementation.
135+
// Return the details of the output tile generated by the tiled
136+
// implementation.
189137
LogicalResult
190138
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
191139
ArrayRef<OpFoldResult> offsets,
192140
ArrayRef<OpFoldResult> sizes,
193-
SmallVectorImpl<OpFoldResult> &resultOffsets,
194-
SmallVectorImpl<OpFoldResult> &resultSizes) const {
141+
SmallVector<OpFoldResult> &resultOffsets,
142+
SmallVector<OpFoldResult> &resultSizes) const {
195143
Location loc = op->getLoc();
196144
LinalgOp linalgOp = cast<LinalgOp>(op);
197145

@@ -212,21 +160,6 @@ struct LinalgOpTilingInterface
212160
return success();
213161
}
214162

215-
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
216-
Operation *op, OpBuilder &b, unsigned operandNumber,
217-
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
218-
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219-
auto tilingInterfaceOp = cast<TilingInterface>(op);
220-
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
221-
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
222-
return emitError(
223-
op->getLoc(),
224-
"unable to obtain the iter domain position of the operation.");
225-
}
226-
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
227-
mappedSizes);
228-
}
229-
230163
FailureOr<TilingResult>
231164
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
232165
ArrayRef<OpFoldResult> offsets,
@@ -244,16 +177,29 @@ struct LinalgOpTilingInterface
244177
"unhandled tiled implementation generation when result is not "
245178
"accessed using a permuted projection");
246179
}
247-
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
248-
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
249-
mappedOffsets, mappedSizes);
250-
auto tilingInterfaceOp = cast<TilingInterface>(op);
251-
FailureOr<TilingResult> tilingResult =
252-
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
253180

254-
if (failed(tilingResult))
255-
return failure();
181+
auto numLoops = linalgOp.getNumLoops();
182+
auto tilingInterfaceOp = cast<TilingInterface>(op);
183+
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
184+
iterationTileSizes(numLoops);
185+
if (!indexingMap.isPermutation()) {
186+
SmallVector<Range> iterationDomain =
187+
tilingInterfaceOp.getIterationDomain(b);
188+
for (const auto &range : llvm::enumerate(iterationDomain)) {
189+
iterationTileOffsets[range.index()] = range.value().offset;
190+
iterationTileSizes[range.index()] = range.value().size;
191+
}
192+
}
193+
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
194+
unsigned dimPosition =
195+
cast<AffineDimExpr>(resultExpr.value()).getPosition();
196+
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
197+
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
198+
}
256199

200+
FailureOr<TilingResult> tilingResult =
201+
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
202+
iterationTileSizes);
257203
if (tilingResult->tiledOps.size() != 1)
258204
return op->emitOpError("failed to generate tiled implementation");
259205

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
6161
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
6262
ArrayRef<OpFoldResult> offsets,
6363
ArrayRef<OpFoldResult> sizes,
64-
SmallVectorImpl<OpFoldResult> &resultOffsets,
65-
SmallVectorImpl<OpFoldResult> &resultSizes) const {
64+
SmallVector<OpFoldResult> &resultOffsets,
65+
SmallVector<OpFoldResult> &resultSizes) const {
6666
resultOffsets.assign(offsets.begin(), offsets.end());
6767
resultSizes.assign(sizes.begin(), sizes.end());
6868
return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
199199
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
200200
ArrayRef<OpFoldResult> offsets,
201201
ArrayRef<OpFoldResult> sizes,
202-
SmallVectorImpl<OpFoldResult> &resultOffsets,
203-
SmallVectorImpl<OpFoldResult> &resultSizes) const {
202+
SmallVector<OpFoldResult> &resultOffsets,
203+
SmallVector<OpFoldResult> &resultSizes) const {
204204
// The iteration domain is over outer dimensions of packed layout. In this
205205
// context, the outer dimensions of `resultOffsets` are `offsets`. The
206206
// inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
452452
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
453453
ArrayRef<OpFoldResult> offsets,
454454
ArrayRef<OpFoldResult> sizes,
455-
SmallVectorImpl<OpFoldResult> &resultOffsets,
456-
SmallVectorImpl<OpFoldResult> &resultSizes) const {
455+
SmallVector<OpFoldResult> &resultOffsets,
456+
SmallVector<OpFoldResult> &resultSizes) const {
457457
resultOffsets = llvm::to_vector(offsets);
458458
resultSizes = llvm::to_vector(sizes);
459459
return success();

0 commit comments

Comments
 (0)