Skip to content

Commit 4c3db25

Browse files
authored
[mlir][linalg] Block pack matmul pass (llvm#89782)
Pack a matmul MxNxK operation into 4D blocked layout. Any present batch dimensions remain unchanged and the result is unpacked back to the original layout. Matmul block packing splits the operands into major blocks (outer dimensions) and minor blocks (inner dimensions). The desired block layout can be controlled through packing options.
1 parent 8ed7ea0 commit 4c3db25

File tree

7 files changed

+1106
-0
lines changed

7 files changed

+1106
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,63 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
141141
];
142142
}
143143

144+
def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
145+
let summary = "Convert linalg matmul ops to block layout and back";
146+
let description = [{
147+
Pack a matmul operation into blocked layout with two levels of subdivision:
148+
- major 2D blocks - outer dimensions, consist of minor blocks
149+
- minor 2D blocks - inner dimensions, consist of scalar elements
150+
151+
A 2D matmul MxNxK gets reshaped into blocked 4D representation
152+
as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb]
153+
where the (MB, NB, KB) dimensions represent the major blocks,
154+
and the (mb, nb, kb) are the minor blocks of their respective
155+
original 2D dimensions (M, N, K).
156+
157+
Depending on the initial operands' data layout and the specified
158+
packing options, the major blocks dimensions might get transposed
159+
e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
160+
e.g., [mb][kb] -> [kb][mb].
161+
Any present batch dimensions remain unchanged.
162+
The final result is unpacked back to the original shape.
163+
164+
For example, given a matmul operation:
165+
```mlir
166+
%res = linalg.matmul ins(%A, %B) outs(%C)
167+
```
168+
the default transformation result can be represented as:
169+
```mlir
170+
%A_packed = pack %A : 2D <MxK> -> 4D <MBxKBxmbxkb>
171+
%B_packed = pack %B : 2D <KxN> -> 4D <NBxKBxnbxkb>
172+
%C_packed = pack %C : 2D <MxN> -> 4D <MBxNBxmbxnb>
173+
%res_packed = linalg.mmt4d ins(%A_packed, %B_packed) outs(%C_packed)
174+
%res = unpack %res_packed : 4D <MBxNBxmbxnb> -> 2D <MxN>
175+
```
176+
}];
177+
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
178+
let options = [
179+
ListOption<"blockFactors", "block-factors", "int64_t",
180+
"Block factors (mb, nb, kb) for relayout">,
181+
Option<"allowPadding", "allow-padding", "bool",
182+
/*default=*/"true",
183+
"Allow packing padding">,
184+
ListOption<"mnkPaddedSizesNextMultipleOf", "mnk-padded-multiples", "int64_t",
185+
"Next multiples of the packing sizes">,
186+
ListOption<"mnkOrder", "mnk-order", "int64_t",
187+
"Permutation of matmul (M, N, K) dimensions order">,
188+
Option<"lhsTransposeOuterBlocks", "lhs-transpose-outer-blocks", "bool",
189+
/*default=*/"false",
190+
"Transpose LHS outer block layout [MB][KB] -> [KB][MB]">,
191+
Option<"lhsTransposeInnerBlocks", "lhs-transpose-inner-blocks", "bool",
192+
/*default=*/"false",
193+
"Transpose LHS inner block layout [mb][kb] -> [kb][mb]">,
194+
Option<"rhsTransposeOuterBlocks", "rhs-transpose-outer-blocks", "bool",
195+
/*default=*/"true",
196+
"Transpose RHS outer block layout [KB][NB] -> [NB][KB]">,
197+
Option<"rhsTransposeInnerBlocks", "rhs-transpose-inner-blocks", "bool",
198+
/*default=*/"true",
199+
"Transpose RHS inner block layout [kb][nb] -> [nb][kb]">
200+
];
201+
}
202+
144203
#endif // MLIR_DIALECT_LINALG_PASSES

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,66 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
11621162
ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
11631163
ArrayRef<int64_t> mnkOrder);
11641164

1165+
struct BlockPackMatmulOptions {
1166+
/// Minor block factors (mb, nb, kb) for packing relayout where mb, mn are
1167+
/// the parallel dimensions and kb is the reduction dimension.
1168+
SmallVector<int64_t, 3> blockFactors;
1169+
1170+
/// If true, allows packing of dimensions that only partially fit into the
1171+
/// block factors.
1172+
bool allowPadding = true;
1173+
1174+
/// Next multiples of the packing sizes.
1175+
SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;
1176+
1177+
/// Permutation of matmul (M, N, K) dimensions order.
1178+
SmallVector<int64_t, 3> mnkOrder = {0, 1, 2};
1179+
1180+
/// Transpose LHS outer block layout [MB][KB] -> [KB][MB].
1181+
bool lhsTransposeOuterBlocks = false;
1182+
1183+
/// Transpose LHS inner block layout [mb][kb] -> [kb][mb].
1184+
bool lhsTransposeInnerBlocks = false;
1185+
1186+
/// Transpose RHS outer block layout [KB][NB] -> [NB][KB].
1187+
bool rhsTransposeOuterBlocks = true;
1188+
1189+
/// Transpose RHS inner block layout [kb][nb] -> [nb][kb].
1190+
bool rhsTransposeInnerBlocks = true;
1191+
};
1192+
1193+
/// Function type which is used to control matmul packing.
1194+
/// It is expected to return valid packing configuration for each operation.
1195+
/// Lack of packing options indicates that no valid configuration could be
1196+
/// assigned and the operation will not be packed.
1197+
using ControlBlockPackMatmulFn =
1198+
std::function<std::optional<BlockPackMatmulOptions>(linalg::LinalgOp)>;
1199+
1200+
/// Pack a matmul operation into blocked 4D layout.
1201+
///
1202+
/// Relayout a matmul operation into blocked layout with two levels of
1203+
/// subdivision:
1204+
/// - major 2D blocks - outer dimensions, consist of minor blocks
1205+
/// - minor 2D blocks - inner dimensions, consist of scalar elements
1206+
///
1207+
/// A 2D matmul MxNxK gets reshaped into blocked 4D representation
1208+
/// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb]
1209+
/// where the (MB, NB, KB) dimensions represent the major blocks,
1210+
/// and the (mb, nb, kb) are the minor blocks of their respective
1211+
/// original 2D dimensions (M, N, K).
1212+
///
1213+
/// Depending on the initial operands' data layout and the specified
1214+
/// packing options, the major blocks dimensions might get transposed
1215+
/// e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
1216+
/// e.g., [mb][kb] -> [kb][mb].
1217+
/// Any present batch dimensions remain unchanged.
1218+
/// The final result is unpacked back to the original shape.
1219+
///
1220+
/// Return failure if no valid packing options are provided.
1221+
FailureOr<PackResult>
1222+
blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
1223+
const ControlBlockPackMatmulFn &controlPackMatmul);
1224+
11651225
/// Rewrite tensor.from_elements to linalg.generic.
11661226
FailureOr<Operation *>
11671227
rewriteInDestinationPassingStyle(RewriterBase &rewriter,
@@ -1628,6 +1688,10 @@ void populateSplitReductionPattern(
16281688
void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
16291689
bool transposeLHS = true);
16301690

1691+
/// Patterns to block pack Linalg matmul ops.
1692+
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
1693+
const ControlBlockPackMatmulFn &controlFn);
1694+
16311695
} // namespace linalg
16321696
} // namespace mlir
16331697

0 commit comments

Comments
 (0)