@@ -1162,6 +1162,66 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
1162
1162
ArrayRef<int64_t > mnkPaddedSizesNextMultipleOf,
1163
1163
ArrayRef<int64_t > mnkOrder);
1164
1164
1165
+ struct BlockPackMatmulOptions {
1166
+ /// Minor block factors (mb, nb, kb) for packing relayout where mb, nb are
1167
+ /// the parallel dimensions and kb is the reduction dimension.
1168
+ SmallVector<int64_t , 3 > blockFactors;
1169
+
1170
+ /// If true, allows packing of dimensions that only partially fit into the
1171
+ /// block factors. Enabled by default.
1172
+ bool allowPadding = true ;
1173
+
1174
+ /// Next multiples of the packing sizes. Empty unless set by the caller.
1175
+ SmallVector<int64_t , 3 > mnkPaddedSizesNextMultipleOf;
1176
+
1177
+ /// Permutation of matmul (M, N, K) dimensions order. Defaults to identity.
1178
+ SmallVector<int64_t , 3 > mnkOrder = {0 , 1 , 2 };
1179
+
1180
+ /// Transpose LHS outer block layout [MB][KB] -> [KB][MB]. Off by default.
1181
+ bool lhsTransposeOuterBlocks = false ;
1182
+
1183
+ /// Transpose LHS inner block layout [mb][kb] -> [kb][mb]. Off by default.
1184
+ bool lhsTransposeInnerBlocks = false ;
1185
+
1186
+ /// Transpose RHS outer block layout [KB][NB] -> [NB][KB]. On by default.
1187
+ bool rhsTransposeOuterBlocks = true ;
1188
+
1189
+ /// Transpose RHS inner block layout [kb][nb] -> [nb][kb]. On by default.
1190
+ bool rhsTransposeInnerBlocks = true ;
1191
+ };
1192
+
1193
+ /// Function type which is used to control matmul packing.
1194
+ /// It is expected to return a valid packing configuration for each operation.
1195
+ /// Returning std::nullopt indicates that no valid configuration could be
1196
+ /// assigned and the operation will not be packed.
1197
+ using ControlBlockPackMatmulFn =
1198
+ std::function<std::optional<BlockPackMatmulOptions>(linalg::LinalgOp)>;
1199
+
1200
+ /// Pack a matmul operation into a blocked 4D layout.
1201
+ ///
1202
+ /// Relayout a matmul operation into a blocked layout with two levels of
1203
+ /// subdivision:
1204
+ /// - major 2D blocks - outer dimensions, consisting of minor blocks
1205
+ /// - minor 2D blocks - inner dimensions, consisting of scalar elements
1206
+ ///
1207
+ /// A 2D matmul MxNxK gets reshaped into a blocked 4D representation
1208
+ /// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb]
1209
+ /// where the (MB, NB, KB) dimensions represent the major blocks,
1210
+ /// and the (mb, nb, kb) are the minor blocks of their respective
1211
+ /// original 2D dimensions (M, N, K).
1212
+ ///
1213
+ /// Depending on the initial operands' data layout and the specified
1214
+ /// packing options, the major block dimensions might get transposed,
1215
+ /// e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed,
1216
+ /// e.g., [mb][kb] -> [kb][mb].
1217
+ /// Any present batch dimensions remain unchanged.
1218
+ /// The final result is unpacked back to the original shape.
1219
+ ///
1220
+ /// Returns failure if `controlPackMatmul` provides no valid packing options.
1221
+ FailureOr<PackResult>
1222
+ blockPackMatmul (RewriterBase &rewriter, linalg::LinalgOp linalgOp,
1223
+ const ControlBlockPackMatmulFn &controlPackMatmul);
1224
+
1165
1225
/// Rewrite tensor.from_elements to linalg.generic.
1166
1226
FailureOr<Operation *>
1167
1227
rewriteInDestinationPassingStyle (RewriterBase &rewriter,
@@ -1628,6 +1688,10 @@ void populateSplitReductionPattern(
1628
1688
void populateTransposeMatmulPatterns (RewritePatternSet &patterns,
1629
1689
bool transposeLHS = true );
1630
1690
1691
+ /// Patterns to block pack Linalg matmul ops, configured per op by `controlFn`.
1692
+ void populateBlockPackMatmulPatterns (RewritePatternSet &patterns,
1693
+ const ControlBlockPackMatmulFn &controlFn);
1694
+
1631
1695
} // namespace linalg
1632
1696
} // namespace mlir
1633
1697
0 commit comments