Commit 996d7a9

[mlir][linalg] Enable fuse consumer
This patch adds support for consumer fusion to the tiling interface and implements consumer fusion in FuseIntoContainingOp.
- Add the interface method 'getIterDomainTilePositionFromOperandPosition' to the tiling interface, which computes the iteration domain tile position from an operand tile position.
- Add the interface method 'getTiledImplementationFromOperandPosition' to the tiling interface, which generates the tiled implementation from an operand tile position.
- Implement the two methods above and support consumer fusion in FuseIntoContainingOp.
1 parent cbcdf12 commit 996d7a9

5 files changed: +590 -103 lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 29 additions & 19 deletions
@@ -310,51 +310,61 @@ def FuseIntoContainingOp :
           ["allowsRepeatedHandleOperands"]>,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        ReportTrackingListenerFailuresOpTrait]> {
-  let summary = "Fuse a producer into a containing operation.";
+  let summary = "Fuse a target into a containing operation.";

   let description = [{
-    Fuses the `producer_op` into the `containing_op`.
+    Fuses the `target_op` into the `containing_op`.
     Returns a handle to the fused ops and the `new_containing_op`.

-    The producer is typically a slice of a tileable op (i.e., implements
-    TilingInterface). In that case, this transform computes the accessed
-    producer slice inside of the containing op ("tile and fuse") and if required,
-    creates a new containing op with outputs from the fused producer. Otherwise,
-    the entire producer is cloned inside the containing op ("clone and fuse").
+    This operation supports fusion of producer or fusion of consumer. We will
+    refer to the value connecting the containing operation and the target
+    operation as the "bridge" below.
+
+    When fuse producer, the bridge is typically a slice of a tileable op (i.e.,
+    implements TilingInterface). In that case, this transform computes the
+    accessed bridge slice inside of the containing op ("tile and fuse") and
+    if required, creates a new containing op with outputs from the fused target.
+    Otherwise, the entire target is cloned inside the containing op ("clone
+    and fuse").
+
+    When fuse consumer, the bridge is the result of containing op and a operand
+    of a tileable op (i.e., implements TilingInterface). In this case, this
+    transform computes the access bridge slice inside the containing op ("tile
+    and fuse") and creates a new containing op with consumer's output.

     The containing op handle must be associated with exactly one payload op. The
-    producer op handle may be associated with multiple payload ops. This
-    transform fuses producers one-by-one, always picking an unspecified producer
+    target op handle may be associated with multiple payload ops. This
+    transform fuses targets one-by-one, always picking an unspecified target
     that has at least one use inside the containing op among the
-    producers. A producer can be listed multiple times in the handle.
+    targets. A target can be listed multiple times in the handle.

-    Note: If a producer has multiple uses inside the containing op, it is
+    Note: If a target has multiple uses inside the containing op, it is
     currently tiled and/or cloned multiple times into the containing op.
     TODO: Reuse already fused OpResults instead of tiling/cloning a second time
-    when possible. Fuse producers according to a topological sorting to achieve
+    when possible. Fuse targets according to a topological sorting to achieve
     the largest amount of reuse.

     #### Return modes

-    If at least one producer could not be fused, this operation produces a
+    If at least one target could not be fused, this operation produces a
     silenceable failure. This is the case when tiling fails or when no
-    producer op could be found among the remaining producers that has at least
-    one use within the containing op. I.e., "producers" that are not consumed
+    target op could be found among the remaining targets that has at least
+    one use within the containing op. I.e., "targets" that are not consumed
     within the containing op are rejected by this operation.

-    This operation consumes the producer handle.
+    This operation consumes the target handle.
     This operation only reads the containing op handle.
   }];

-  let arguments = (ins TransformHandleTypeInterface:$producer_op,
+  let arguments = (ins TransformHandleTypeInterface:$target_op,
                        TransformHandleTypeInterface:$containing_op);
   let results = (outs TransformHandleTypeInterface:$fused_op,
                       TransformHandleTypeInterface:$new_containing_op);
-  let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
+  let assemblyFormat = "$target_op `into` $containing_op attr-dict "
                        " `:` functional-type(operands, results)";

   let builders = [
-    OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)>
+    OpBuilder<(ins "Value":$targetOp, "Value":$containingOp)>
   ];
 }
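
The updated description distinguishes producer fusion from consumer fusion by the direction of the "bridge" value. The sketch below is one way to read that distinction on a toy model. It is not the code from this commit: `ToyOp`, `FusionKind`, and `classifyFusion` are hypothetical names, values are plain strings, and the real transform works on MLIR use-def chains that this diff does not show.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified op: the values it defines and the values it reads
// (including values read inside nested regions).
struct ToyOp {
  std::vector<std::string> results;
  std::vector<std::string> operands;
};

enum class FusionKind { Producer, Consumer, None };

static bool contains(const std::vector<std::string> &values,
                     const std::string &value) {
  return std::find(values.begin(), values.end(), value) != values.end();
}

// Looks for a bridge value between `target` and `containing`.
// - Producer fusion: a result of the target is read by the containing op.
// - Consumer fusion: a result of the containing op is read by the target.
// On success, `bridge` is set to the connecting value.
FusionKind classifyFusion(const ToyOp &target, const ToyOp &containing,
                          std::string &bridge) {
  for (const std::string &result : target.results) {
    if (contains(containing.operands, result)) {
      bridge = result;
      return FusionKind::Producer;
    }
  }
  for (const std::string &result : containing.results) {
    if (contains(target.operands, result)) {
      bridge = result;
      return FusionKind::Consumer;
    }
  }
  return FusionKind::None;
}
```

In this model, a target whose result feeds the loop is classified as a producer, a target that reads the loop result is classified as a consumer, and a target with no connecting value corresponds to the rejected case described under "Return modes".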

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 55 additions & 0 deletions
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return {};
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return iterator domain position computed by the
+          input operand position.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand position.
+
+          Generates the IR that generate the tiled implementation of an
+          operation from operand position. The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the position of the
+          operand required. This method enables fusion consumer by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
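
To make the contract of `getIterDomainTilePositionFromOperandPosition` more concrete, here is a minimal standalone sketch of mapping an operand tile to an iteration-domain tile. It is not the implementation from this commit: it uses plain integers instead of `OpFoldResult`, and it assumes that each operand dimension is indexed by exactly one loop (a projected permutation) and that loops the operand does not reference keep their full range. The names `Tile` and `iterDomainTileFromOperandTile` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for per-dimension offsets and sizes of a tile.
struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Maps a tile of one operand to a tile of the iteration domain.
// `loopRanges[i]` is the extent of loop i (assumed to start at 0).
// `operandDimToLoop[d]` is the loop that indexes operand dimension d.
// Returns std::nullopt, playing the role of failure(), when the inputs
// do not describe a valid projected permutation.
std::optional<Tile> iterDomainTileFromOperandTile(
    const std::vector<int64_t> &loopRanges,
    const std::vector<unsigned> &operandDimToLoop, const Tile &operandTile) {
  if (operandTile.offsets.size() != operandDimToLoop.size() ||
      operandTile.sizes.size() != operandDimToLoop.size())
    return std::nullopt;

  // Start from the full iteration domain: offset 0 and full extent per loop.
  Tile result;
  result.offsets.assign(loopRanges.size(), 0);
  result.sizes = loopRanges;

  std::vector<bool> seen(loopRanges.size(), false);
  for (std::size_t d = 0; d < operandDimToLoop.size(); ++d) {
    unsigned loop = operandDimToLoop[d];
    // Reject out-of-range loops and loops mapped from two operand dimensions.
    if (loop >= loopRanges.size() || seen[loop])
      return std::nullopt;
    seen[loop] = true;
    // The operand tile pins the offset and size of the loop that indexes it.
    result.offsets[loop] = operandTile.offsets[d];
    result.sizes[loop] = operandTile.sizes[d];
  }
  return result;
}
```

For a matmul-like op with loops (m, n, k) of extents (128, 64, 32) and an operand indexed by (m, k), an operand tile with offsets (16, 0) and sizes (8, 32) maps to iteration-domain offsets (16, 0, 0) and sizes (8, 64, 32) under these assumptions. The resulting iteration-domain tile is the kind of input that `getTiledImplementation` expects.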
