[mlir][linalg] Block pack matmul pass #89782
Pack a matmul MxNxK operation into blocked layout as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]. The result is unpacked back to the original layout. Matmul packing splits the operands into smaller blocks (inner dimensions) and then block-transposes the block sub-groups (outer dimensions). This data arrangement minimizes the distance between consecutive blocks, which improves spatial locality and cache behavior.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg
Author: Adam Siemieniuk (adam-smnk)
Full diff: https://github.com/llvm/llvm-project/pull/89782.diff (5 files affected):
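As a sanity check of the layout described above, the relayout and blocked accumulation can be sketched in NumPy. This is only an illustrative model (the pass itself materializes the layout with tensor.pack/tensor.unpack); the block factors mirror the lit test in this PR.

```python
import numpy as np

M, N, K = 128, 128, 128
mb, nb, kb = 32, 16, 64  # block factors, as in block-factors=32,16,64

A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)

# Pack: split each dim into (outer, inner) blocks; the RHS outer
# dims are block-transposed ([NB][KB] instead of [KB][NB]).
Ap = A.reshape(M // mb, mb, K // kb, kb).transpose(0, 2, 1, 3)  # [MB][KB][mb][kb]
Bp = B.reshape(K // kb, kb, N // nb, nb).transpose(2, 0, 1, 3)  # [NB][KB][kb][nb]
Cp = np.zeros((M // mb, N // nb, mb, nb), dtype=np.float32)     # [MB][NB][mb][nb]

# Blocked matmul: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
for MB in range(M // mb):
    for NB in range(N // nb):
        for KB in range(K // kb):
            Cp[MB, NB] += Ap[MB, KB] @ Bp[NB, KB]

# Unpack back to the original layout.
C_out = Cp.transpose(0, 2, 1, 3).reshape(M, N)
assert np.allclose(C_out, A @ B, atol=1e-2)
```

With this arrangement, the kb x nb tiles of the RHS that one output block consumes sit contiguously in memory, which is the locality argument made in the description.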
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..d4361c70468bdb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,24 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
];
}
+def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
+ let summary = "Convert linalg matmul ops to block layout and back";
+ let description = [{
+ Pack a matmul MxNxK operation into blocked layout
+ as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
+ The result is unpacked back to the original layout.
+
+ Matmul packing splits the operands into smaller blocks (inner dimensions)
+ and then block-transposes the block sub-groups (outer dimensions).
+
+    This data arrangement minimizes the distance between consecutive blocks,
+    which improves spatial locality and cache behavior.
+ }];
+ let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+ let options = [
+ ListOption<"blockFactors", "block-factors", "int64_t",
+ "Block factors (mb, nb, kb) for relayout">
+ ];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5ecf84fa9c7012..2bb9277cc7b27e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1628,6 +1628,10 @@ void populateSplitReductionPattern(
void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
bool transposeLHS = true);
+/// Patterns to pack Linalg matmul ops.
+void populatePackMatmulPatterns(RewritePatternSet &patterns,
+ ArrayRef<int64_t> blockingFactors);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ee6e391d0cc682..e9b104ea5aeb58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TransposeMatmul.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
+ PackMatmul.cpp
Padding.cpp
Promotion.cpp
Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
new file mode 100644
index 00000000000000..304de03a343fdc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
@@ -0,0 +1,177 @@
+//===- PackMatmul.cpp - Linalg matmul packing -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+static std::optional<int64_t> getConstantRange(const Range &range) {
+ std::optional<int64_t> stride = getConstantIntValue(range.stride);
+ if (!stride || *stride != 1)
+ return std::nullopt;
+ std::optional<int64_t> offset = getConstantIntValue(range.offset);
+ if (!offset)
+ return std::nullopt;
+ std::optional<int64_t> size = getConstantIntValue(range.size);
+ if (!size)
+ return std::nullopt;
+ return (*size - *offset);
+}
+
+static bool validateFullTilesOnDims(TilingInterface tileOp,
+ ArrayRef<OpFoldResult> tiles,
+ ArrayRef<size_t> dims) {
+ if (dims.size() != tiles.size() || tiles.empty())
+ return false;
+
+ OpBuilder builder(tileOp);
+ OpBuilder::InsertionGuard guard(builder);
+ SmallVector<Range> iterationDomain =
+ cast<TilingInterface>(tileOp.getOperation()).getIterationDomain(builder);
+
+ for (auto dim : llvm::enumerate(dims)) {
+ if (dim.value() >= iterationDomain.size())
+ return false;
+
+ auto tileSize = getConstantIntValue(tiles[dim.index()]);
+ auto rangeOnDim = getConstantRange(iterationDomain[dim.value()]);
+
+ // If the tile factor or the range are non-constant, the tile size is
+ // considered to be invalid.
+ if (!tileSize || !rangeOnDim)
+ return false;
+
+ // The dimension must be fully divisible by the tile.
+ if (*rangeOnDim % *tileSize != 0)
+ return false;
+ }
+
+ return true;
+}
+
+static FailureOr<linalg::LinalgOp>
+packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+ ArrayRef<OpFoldResult> mnkTiles) {
+ if (!(isa<linalg::MatmulOp>(matmulOp) ||
+ isa<linalg::BatchMatmulOp>(matmulOp))) {
+ return rewriter.notifyMatchFailure(matmulOp, "not a matmul-like operation");
+ }
+
+ if (mnkTiles.size() != 3)
+ return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
+
+ if (matmulOp.hasDynamicShape())
+ return rewriter.notifyMatchFailure(matmulOp, "require static shape");
+
+ if (matmulOp.hasPureBufferSemantics())
+ return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
+
+ SmallVector<size_t, 3> dims{0, 1, 2};
+ // Skip the batch dimension if present.
+ bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(matmulOp);
+ if (isBatchMatmulOp)
+ dims = {1, 2, 3};
+
+ if (!validateFullTilesOnDims(cast<TilingInterface>(matmulOp.getOperation()),
+ mnkTiles, dims)) {
+ return rewriter.notifyMatchFailure(matmulOp,
+ "expect packing full tiles only");
+ }
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ // The op is replaced, we need to set the insertion point after it.
+ rewriter.setInsertionPointAfter(matmulOp);
+
+ auto packedCanonicalMatmul = packMatmulGreedily(
+ rewriter, matmulOp, mnkTiles, /*mnkPaddedSizesNextMultipleOf=*/{},
+ /*mnkOrder=*/{0, 1, 2});
+ if (failed(packedCanonicalMatmul))
+ return failure();
+
+ assert(packedCanonicalMatmul->packOps.size() == 3 && "failed matmul packing");
+ assert(packedCanonicalMatmul->unPackOps.size() == 1 &&
+ "failed matmul unpacking");
+
+ SmallVector<int64_t> innerPerm = {1, 0};
+ SmallVector<int64_t> outerPerm = {1, 0};
+ // Leave the batch dimension as is.
+ if (isBatchMatmulOp)
+ outerPerm = {0, 2, 1};
+
+ auto packedMatmul =
+ packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
+ packedCanonicalMatmul->packedLinalgOp,
+ /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
+ if (failed(packedMatmul))
+ return failure();
+
+ return packedMatmul->transposedLinalgOp;
+}
+
+namespace {
+template <typename OpTy>
+struct PackMatmul : public OpRewritePattern<OpTy> {
+ PackMatmul(MLIRContext *context, ArrayRef<int64_t> blockFactors,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), blockFactors(blockFactors) {}
+
+ LogicalResult matchAndRewrite(OpTy matmulOp,
+ PatternRewriter &rewriter) const override {
+ if (blockFactors.empty())
+ return failure();
+ auto packedMatmul =
+ packMatmulOp(rewriter, matmulOp,
+ getAsOpFoldResult(rewriter.getI64ArrayAttr(blockFactors)));
+ if (failed(packedMatmul))
+ return failure();
+ return success();
+ }
+
+private:
+ SmallVector<int64_t> blockFactors;
+};
+
+// Entry point for packing matmul operations.
+// Pack a MatmulOp as follows:
+// [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+// Pack a BatchMatmulOp as follows:
+// [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
+struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
+ using LinalgPackMatmulBase::LinalgPackMatmulBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ RewritePatternSet patterns(&getContext());
+ linalg::populatePackMatmulPatterns(patterns, blockFactors);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
+ ArrayRef<int64_t> blockFactors) {
+ patterns.add<PackMatmul<linalg::MatmulOp>, PackMatmul<linalg::BatchMatmulOp>>(
+ patterns.getContext(), blockFactors);
+}
diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/pack-matmul.mlir
new file mode 100644
index 00000000000000..d7023cfc30559b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pack-matmul.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
+
+func.func @block_matmul(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME: into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
+// CHECK-SAME: into %[[BUF1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[PACK0]], %[[PACK1]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[PACK2]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_constant(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %cst = arith.constant dense<0.0> : tensor<128x128xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%cst : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_constant(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF_RES:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[BUF_OUT:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[BUF_OUT]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_producer(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %1 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_producer(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[BUF_RES:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[FILL]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_consumer(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %0 = tensor.empty() : tensor<128x128xf32>
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %2 = linalg.add ins(%1, %arg3 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %2 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_consumer(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG3:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: outs({{.*}} : tensor<4x8x32x16xf32>)
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[OUT:.+]] = linalg.add
+// CHECK-SAME: ins(%[[UNPACK]], %[[ARG3]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[BUF]] : tensor<128x128xf32>)
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_batch_matmul(
+ %arg0: tensor<512x64x128xf32>, %arg1: tensor<512x128x64xf32>, %arg2: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+ %0 = tensor.empty() : tensor<512x64x64xf32>
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+ outs(%arg2 : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+ return %1 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<512x64x128xf32>, %[[ARG1:.+]]: tensor<512x128x64xf32>, %[[ARG2:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME: into %[[BUF0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[BUF1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
+// CHECK-SAME: into %[[BUF1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[BUF2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[PACK0]], %[[PACK1]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[PACK2]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[OUT]] : tensor<512x64x64xf32>
FYI @rengolin
This is a packing scheme aimed at CPUs to improve performance of GEMM computations by improving data accesses. The transformation has been tested within our downstream project (TPP-MLIR).
Hi, what is the difference with
Thanks for the change. We have been doing something similar downstream as well.
@yzhang93 can you help review this and guide it based on the use case we have worked through.
if (isBatchMatmulOp)
  dims = {1, 2, 3};

if (!validateFullTilesOnDims(cast<TilingInterface>(matmulOp.getOperation()),
This seems pessimistic. Pack operations have padding semantics, so we don't need to have full tiles only.
I think this is part of the optionality that would be interesting to have (padding vs non-padding)...
Enabled padding by default, but it can be optionally disabled.
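To illustrate the point about padding semantics: when a dimension is not a multiple of the block factor, the operand can still be packed by zero-padding it up to the next tile multiple first. A minimal NumPy sketch with made-up shapes, not tied to this pass's implementation:

```python
import numpy as np

M, mb = 100, 32  # 100 is not divisible by the 32-element tile
A = np.random.rand(M, 8).astype(np.float32)

# Pad up to the next multiple of the tile size (tensor.pack's
# padding_value plays this role), then split into full tiles.
pad = (-M) % mb  # 28 extra rows
A_padded = np.pad(A, ((0, pad), (0, 0)), constant_values=0.0)
A_packed = A_padded.reshape((M + pad) // mb, mb, 8)  # [MB][mb][8]

# The original data is recoverable by unpacking and dropping the pad.
A_roundtrip = A_packed.reshape(-1, 8)[:M]
assert np.array_equal(A_roundtrip, A)
```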
// [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
// Pack a BatchMatmulOp as follows:
// [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
I don't see much value in a pass in MLIR. I think what is more valuable is the packMatmulOp method here.
The pass is a way to get into the functionality and to test it without using the transform dialect. We see value in it to compose with our pipelines, and I see no cost in keeping it upstream. Just having transforms for now is very restrictive, and we don't want to start a discussion between transforms and passes in this (and subsequent) PRs; that is an orthogonal issue.
This PR and the following ones will add passes to drive the framework because that's what we have. After we merge the core functionality we can revisit and add them to upstream transforms.
A huge +1 to having functionality and tests without the transform dialect. My operation mode has been:
- Add an API entry point for the transformation function
- Optionally wrap the transformations within a pattern and have a "populate*Pattern" method
- Have a test pass to test the functionality. I am not super opposed to having a pass though.
That's a good approach. The pass will help us upstream all of the functionality in smaller steps, rebase and recompose our pipeline to make sure we didn't lose anything upstream. Then iterate until complete.
We should also discuss separately how to compose the functionality with other compiler pipelines (like IREE), so that we can reuse the same logic everywhere.
I opted to leave the pass primarily to allow testing and as a usage example.
It could be a test pass but might as well just be a normal pass.
void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
                                        ArrayRef<int64_t> blockFactors) {
  patterns.add<PackMatmul<linalg::MatmulOp>, PackMatmul<linalg::BatchMatmulOp>>(
Actually, looking at this, my initial thought is that we don't need the patterns themselves, but I am OK either way.
Left it for completeness with the pass but potentially both could be removed. No strong opinions here.
static FailureOr<linalg::LinalgOp>
packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
Could we generalize this a bit? The packMatmulGreedily and packTranspose methods expose a lot of options. Can we add a struct that can carry all these options and call these methods (instead of just mnkTiles sizes)? We can make this an API entry point. This is IMO more valuable than the "populate*Patterns" method and the pass added here.
+1
Exposed more options and added more layout controls.
Both share the same core logic under the hood. I imagine you could express this pass as a transform sequence.
I think the two serve a bit different purposes.
For folks less familiar with TPP, could you compare this to rewriting? And could you include pseudo-IR before and after in the summary?
I think this pass may not be general enough for a broader audience. It limits the usage to linalg.matmul/linalg.batch_matmul ops (without any linalg.generic variants) and fixes the innerPerm/outerPerm to {1, 0}. In our downstream development, I found it very flexible to directly call the underlying transform methods.
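For context on the innerPerm/outerPerm knobs mentioned here: they permute, respectively, the inner tile dims and the outer block dims of a packed operand. A small NumPy sketch with made-up shapes:

```python
import numpy as np

kb, nb = 64, 16
packed = np.random.rand(8, 2, kb, nb)  # [NB][KB][kb][nb]

# innerPerm = [1, 0]: transpose each (kb, nb) tile to (nb, kb).
inner_transposed = packed.transpose(0, 1, 3, 2)  # [NB][KB][nb][kb]

# outerPerm = [1, 0]: swap the block-level dims instead.
outer_transposed = packed.transpose(1, 0, 2, 3)  # [KB][NB][kb][nb]

# Each inner tile is transposed independently of the others.
assert np.array_equal(inner_transposed[0, 0], packed[0, 0].T)
```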
Given that there's no other documentation attached to this PR, I'm using these tests to infer the high level logic. That's a bit tricky ATM - it could be simplified with more meaningful var names (MLIR and LIT) ;-)
For example:
func.func @block_matmul(%lhs: tensor<128x128xf32>, %rhs: tensor<128x128xf32>, %out: tensor<128x128xf32>) -> tensor<128x128xf32> {
%0 = linalg.matmul ins(%lhs, %rhs : tensor<128x128xf32>, tensor<128x128xf32>)
outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}
or
func.func @block_matmul(%A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
%0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}
And then in LIT:
// CHECK: %[[LHS_PACKED:.+]] = tensor.pack %[[LHS]]
Or something similar. I don't really mind what the actual names are, as long as they help to map to the high level semantics of Matmul. Updated names will really help to review this, thanks!
Sure thing, I'll also improve the description as suggested. An example with mmt4d should help a lot.
Improved the description and test naming. Hopefully, now they are more understandable.
I'll start with generalizing and exposing the API as Mahesh suggested. But I think there's still value in providing tools for the well defined named ops. Patterns for common linalg.generic variants could be added later too. But if you need to go with truly generic representations, you probably need custom logic anyway.
The aim is not necessarily to create a fully generic infrastructure. After all, there are existing tools for that. Although, if we find enough common patterns, then even better.
Hey, thanks for the contribution! This is exciting! A couple of comments:
- Why are you framing this as a CPU-only transformation? Are there any implicit limitations that we should be aware of?
- How configurable would the layout be? Could we do plain transposes, decide which operands to transpose, and choose how many blocking levels to use? (Sorry if I missed something while skimming through the code)
static FailureOr<linalg::LinalgOp>
packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+1
AFAIK, it's not necessarily beneficial to pack for GPUs, to give a counterexample. We've also used packing just for CPUs, so it's the main motivation and use case I can confidently bring up.
This one worked best based on our use cases. I'll try to generalize the pass and see what can be done.
if (mnkTiles.size() != 3)
  return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
The current implementation only takes one set of mnkTiles and tries to pack all matmul ops with the given tile sizes. Will we consider use cases where we wish to pack different matmul ops with different tile sizes?
Added control callback that can specify packing sizes and layout configuration separately for each op. Hopefully, this addresses your use case.
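The shape of such a control callback can be sketched in Python (hypothetical names; the actual hook is a C++ callable in the pass options):

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

@dataclass
class PackOptions:
    """Per-op packing configuration (hypothetical mirror of the C++ options)."""
    block_factors: Tuple[int, int, int]  # (mb, nb, kb)
    allow_padding: bool = True

# The callback inspects each candidate op and either returns packing
# options for it or None to leave the op in its original layout.
ControlFn = Callable[[str], Optional[PackOptions]]

def pack_all(ops: List[str], control: ControlFn) -> List[Tuple[str, Tuple[int, int, int]]]:
    packed = []
    for op in ops:
        opts = control(op)
        if opts is None:
            continue  # callback declined to pack this op
        packed.append((op, opts.block_factors))
    return packed

result = pack_all(
    ["matmul_small", "matmul_large"],
    lambda op: PackOptions((32, 16, 64)) if op == "matmul_large" else None,
)
assert result == [("matmul_large", (32, 16, 64))]
```

This addresses the earlier review question: different matmul ops can receive different tile sizes, or be skipped entirely.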
Thanks for all the feedback everyone. Changes:
@MaheshRavishankar @yzhang93 Does the more generalized API fit your use cases better now? There are surely still limitations when it comes to more complex generics or arbitrary layouts. Hopefully the API is flexible enough to allow future extensions for the former. The latter might be out of scope here.
This looks reasonable to me, except for having a pass. I'll approve it, but would really appreciate it if we drop the pass from core and move it to tests.
@@ -141,4 +141,63 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
  ];
}

def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
I am not going to hold on to it, but I really dislike having such passes. I think this is just dead weight that can never fit all downstream uses; really we should just need transformation methods/patterns, let the passes live downstream, and have only a test pass for testing. Such passes in core do not age well.
Such passes in core do not age well.
That's true, but I think it's orthogonal. We need the functionality upstream so that we can all use and share it, and it will only age if no one uses it. Being in a test pass, a transform, or a dialect pass won't change that equation much.
I'd love it if IREE and other compilers could bring more functionality to this pass and start using the upstream stuff (parametrized, cost modelled, etc.). This is what makes them not age. If we don't provide a way for people to use it in their projects, then MLIR gets fewer "batteries" (as @joker-eph usually says).
Thanks again for working on this! This landed within an hour of Mahesh approving ... Could you allow a bit of extra time for other reviewers to take a look and to rubber stamp it? ;-) Thank you 🙏🏻
Well, the PR is 2 weeks old and the last update was 2 days ago, so there was plenty of time for reviews. LLVM doesn't really have the concept of rubber stamping; one approval is enough, so what we do is just iterate in-tree. Post-commit reviews are pretty common. It'd be great if you can try it out and bring issues, PRs, and ideas from your side. We want this functionality to work for everyone!
This PR was marked as draft for most of that time, so I assumed that there was no rush and also that the PR wasn't ready for review. From the LLVM docs https://llvm.org/docs/CodeReview.html#lgtm-how-a-patch-is-accepted:
There are many ways to read it. For a bigger change like this one, I really appreciate when people leave ~24 hrs between the first approval and landing in-tree. Not a formal requirement nor policy, just a kind request. Please don't take this the wrong way, I am OK with this change and support the effort. All my comments have been addressed.
EDIT:
From the initial submission it wasn't clear to me that this was targeting
We are also working on enabling this for scalable vectors. Here's our meta-ticket in IREE (the link will take you to a mini RFC in that ticket). Admittedly, our contributions might be less relevant to folks focusing on fixed-width vectors.
This is not related (yet) to that. As discussed above, if we have more users (you, IREE, us), then the code will continue to be relevant, so it'd be really good to get your feedback, and hopefully we can adapt to your usage as well as ours.
Pack a matmul MxNxK operation into a 4D blocked layout. Any present batch dimensions remain unchanged and the result is unpacked back to the original layout.
Matmul block packing splits the operands into major blocks (outer dimensions) and minor blocks (inner dimensions). The desired block layout can be controlled through packing options.
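For a batch matmul, the same scheme applies with the batch dimension kept outermost and untouched. A NumPy sketch with smaller shapes than the lit test, for brevity:

```python
import numpy as np

Bsz, M, K, N = 8, 64, 128, 64
mb, nb, kb = 32, 16, 64

A = np.random.rand(Bsz, M, K).astype(np.float32)
B = np.random.rand(Bsz, K, N).astype(np.float32)

# [B][MB][KB][mb][kb] and [B][NB][KB][kb][nb]: the batch dim stays first.
Ap = A.reshape(Bsz, M // mb, mb, K // kb, kb).transpose(0, 1, 3, 2, 4)
Bp = B.reshape(Bsz, K // kb, kb, N // nb, nb).transpose(0, 3, 1, 2, 4)

# [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
Cp = np.einsum("bmkij,bnkjl->bmnil", Ap, Bp)

# Unpack back to [B][M][N].
C = Cp.transpose(0, 1, 3, 2, 4).reshape(Bsz, M, N)
assert np.allclose(C, A @ B, atol=1e-2)
```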