Skip to content

Commit 5443743

Browse files
[mlir][Linalg] Add a transform.structured.pack operation
This revision introduces a `transform.structured.pack` operation to transform any Linalg operation to a higher-dimensional Linalg operation on packed operands. `tensor.pack` (resp. `tensor.unpack`) operations are inserted for the operands (resp. results) that need to be packed (resp. unpacked) according to the `packed_sizes` specification. At the moment, the packing operation always pads with `getZeroAttr` which will need to be adjusted depending on the consumers. Packing is limited to those dimensions that are indexed only by AffineDimExpr. Packing more advanced indexings requires modular arithmetic that is outside the scoped of a `linalg.generic` at the moment. Differential Revision: https://reviews.llvm.org/D141860
1 parent 78ba3e7 commit 5443743

File tree

3 files changed

+890
-70
lines changed

3 files changed

+890
-70
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,90 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
341341
}];
342342
}
343343

344+
def PackOp : Op<Transform_Dialect, "structured.pack", [
345+
TransformOpInterface,
346+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,]> {
347+
let description = [{
348+
Pack a LinalgOp by applying a data tiling transformation on the op and
349+
packing the operands according to the `packed_sizes` specification.
350+
351+
Iterator dimensions are tiled in their canonical order in the op spec.
352+
Operands are packed according to the same canonical order of the op iterator
353+
dimensions.
354+
355+
Specifying a packed size of 0 for an iterator removes it from consideration
356+
for packing.
357+
358+
`tensor.pack` (resp. `tensor.unpack`) operations are inserted for the operands
359+
(resp. results) that need to be packed (resp. unpacked) according to the
360+
`packed_sizes` specification.
361+
362+
#### Example
363+
364+
Consider a `linalg.matmul` with indexing maps:
365+
```
366+
// M N K M K
367+
// affine_map<(d0, d1, d2) -> (d0, d2)>
368+
// K N
369+
// affine_map<(d0, d1, d2) -> (d2, d1)>
370+
// M N
371+
// affine_map<(d0, d1, d2) -> (d0, d1)>
372+
%0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
373+
outs( %C: tensor<?x?xf32>)
374+
```
375+
376+
Specifying packed_sizes [2, 3, 4] results in tiling the iterator dimensions
377+
M, N and K, in this order, in both the op and its operands.
378+
```
379+
// M N K m n k M K m k
380+
// affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
381+
// K N n k
382+
// affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
383+
// M N m n
384+
// affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
385+
%0 = linalg.generic_representing_some_higher_d_matmul
386+
ins(%A, %B: tensor<?x?x2x4xf32>, tensor<?x?x4x3xf32>)
387+
outs( %C: tensor<?x?x2x4xf32>)
388+
```
389+
In particular, note that the second operand `B` has shape `KxNxnxk` (and not
390+
`KxNxkxn` as one could expect by looking **only** at the operand).
391+
392+
Other layouts can be obtained unsurprisingly from this canonical
393+
transformation by composing the resulting operation with a (future)
394+
`transform.structured.pack_transpose` op.
395+
This composition allows separating concerns and composes better compared
396+
to adding additional permutation attributes to this transform op.
397+
398+
#### Return modes
399+
400+
This operation applies to a single Linalg op, otherwise it fails.
401+
This operation may produce a definiteFailure if the packing fails for any
402+
reason.
403+
404+
The returned handle point to the packed LinalgOp.
405+
}];
406+
407+
let arguments = (ins TransformHandleTypeInterface:$target,
408+
Variadic<PDL_Operation>:$packed_sizes,
409+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_packed_sizes);
410+
let results = (outs TransformHandleTypeInterface:$packed_op);
411+
let assemblyFormat = [{
412+
$target
413+
`packed_sizes` `=` custom<DynamicIndexList>($packed_sizes,
414+
$static_packed_sizes)
415+
attr-dict
416+
`:` functional-type($target, results)
417+
}];
418+
419+
let extraClassDeclaration = [{
420+
::mlir::DiagnosedSilenceableFailure apply(
421+
transform::TransformResults &transformResults,
422+
transform::TransformState &state);
423+
424+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedPackedSizes();
425+
}];
426+
}
427+
344428
//===----------------------------------------------------------------------===//
345429
// PadOp
346430
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)