[mlir][linalg] Block pack matmul pass #89782

Merged: 26 commits merged into llvm:main on May 9, 2024

Conversation

@adam-smnk (Contributor) commented Apr 23, 2024

Pack a matmul MxNxK operation into a 4D blocked layout. Any present batch dimensions remain unchanged, and the result is unpacked back to the original layout.

Matmul block packing splits the operands into major blocks (outer dimensions) and minor blocks (inner dimensions). The desired block layout can be controlled through packing options.

Pack a matmul MxNxK operation into blocked layout
as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
The result is unpacked back to the original layout.

Matmul packing splits the operands into smaller blocks (inner dimensions)
and then block-transposes the block sub-groups (outer dimensions).

This data arrangement minimizes the distance between consecutive blocks,
which improves spatial locality and cache behavior.
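
For illustration, here is a simplified before/after sketch for a 128x128x128 matmul with block factors mb=32, nb=16, kb=64 (hand-written to mirror the pass's test expectations; value names are made up and the exact IR produced by the pass may differ after canonicalization):

// Before: a plain row-major matmul.
func.func @matmul(%A: tensor<128x128xf32>, %B: tensor<128x128xf32>,
                  %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

// After (sketch): operands are packed into blocked layouts, the contraction
// runs as a 6-D linalg.generic over (MB, NB, KB, mb, nb, kb), and the result
// is unpacked back to the original layout.
#map_a = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> // A: [MB][KB][mb][kb]
#map_b = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> // B: [NB][KB][kb][nb]
#map_c = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> // C: [MB][NB][mb][nb]
func.func @matmul_blocked(%A: tensor<128x128xf32>, %B: tensor<128x128xf32>,
                          %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %a_buf = tensor.empty() : tensor<4x2x32x64xf32>
  %A_packed = tensor.pack %A inner_dims_pos = [0, 1] inner_tiles = [32, 64]
      into %a_buf : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
  %b_buf = tensor.empty() : tensor<8x2x64x16xf32>
  %B_packed = tensor.pack %B outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
      into %b_buf : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
  %c_buf = tensor.empty() : tensor<4x8x32x16xf32>
  %C_packed = tensor.pack %C inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %c_buf : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
  %res = linalg.generic
      {indexing_maps = [#map_a, #map_b, #map_c],
       iterator_types = ["parallel", "parallel", "reduction",
                         "parallel", "parallel", "reduction"]}
      ins(%A_packed, %B_packed : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>)
      outs(%C_packed : tensor<4x8x32x16xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %m = arith.mulf %a, %b : f32
    %s = arith.addf %c, %m : f32
    linalg.yield %s : f32
  } -> tensor<4x8x32x16xf32>
  %out = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %C : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
  return %out : tensor<128x128xf32>
}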
@llvmbot (Member) commented Apr 23, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Adam Siemieniuk (adam-smnk)

Changes

Pack a matmul MxNxK operation into blocked layout
as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]. The result is unpacked back to the original layout.

Matmul packing splits the operands into smaller blocks (inner dimensions) and then block-transposes the block sub-groups (outer dimensions).

This data arrangement minimizes the distance between consecutive blocks, which improves spatial locality and cache behavior.


Full diff: https://github.com/llvm/llvm-project/pull/89782.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+20)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp (+177)
  • (added) mlir/test/Dialect/Linalg/pack-matmul.mlir (+140)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..d4361c70468bdb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,24 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
+def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
+  let summary = "Convert linalg matmul ops to block layout and back";
+  let description = [{
+    Pack a matmul MxNxK operation into blocked layout
+    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
+    The result is unpacked back to the original layout.
+
+    Matmul packing splits the operands into smaller blocks (inner dimensions)
+    and then block-transposes the block sub-groups (outer dimensions).
+
+    This data arrangement minimizes distance between consecutive blocks
+    which improves spacial locality and cache behavior.
+  }];
+  let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+  let options = [
+    ListOption<"blockFactors", "block-factors", "int64_t",
+               "Block factors (mb, nb, kb) for relayout">
+  ];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5ecf84fa9c7012..2bb9277cc7b27e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1628,6 +1628,10 @@ void populateSplitReductionPattern(
 void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                      bool transposeLHS = true);
 
+/// Patterns to pack Linalg matmul ops.
+void populatePackMatmulPatterns(RewritePatternSet &patterns,
+                                ArrayRef<int64_t> blockingFactors);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ee6e391d0cc682..e9b104ea5aeb58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
+  PackMatmul.cpp
   Padding.cpp
   Promotion.cpp
   Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
new file mode 100644
index 00000000000000..304de03a343fdc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
@@ -0,0 +1,177 @@
+//===- PackMatmul.cpp - Linalg matmul packing -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  std::optional<int64_t> stride = getConstantIntValue(range.stride);
+  if (!stride || *stride != 1)
+    return std::nullopt;
+  std::optional<int64_t> offset = getConstantIntValue(range.offset);
+  if (!offset)
+    return std::nullopt;
+  std::optional<int64_t> size = getConstantIntValue(range.size);
+  if (!size)
+    return std::nullopt;
+  return (*size - *offset);
+}
+
+static bool validateFullTilesOnDims(TilingInterface tileOp,
+                                    ArrayRef<OpFoldResult> tiles,
+                                    ArrayRef<size_t> dims) {
+  if (dims.size() != tiles.size() || tiles.empty())
+    return false;
+
+  OpBuilder builder(tileOp);
+  OpBuilder::InsertionGuard guard(builder);
+  SmallVector<Range> iterationDomain =
+      cast<TilingInterface>(tileOp.getOperation()).getIterationDomain(builder);
+
+  for (auto dim : llvm::enumerate(dims)) {
+    if (dim.value() >= iterationDomain.size())
+      return false;
+
+    auto tileSize = getConstantIntValue(tiles[dim.index()]);
+    auto rangeOnDim = getConstantRange(iterationDomain[dim.value()]);
+
+    // If the tile factor or the range are non-constant, the tile size is
+    // considered to be invalid.
+    if (!tileSize || !rangeOnDim)
+      return false;
+
+    // The dimension must be fully divisible by the tile.
+    if (*rangeOnDim % *tileSize != 0)
+      return false;
+  }
+
+  return true;
+}
+
+static FailureOr<linalg::LinalgOp>
+packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+             ArrayRef<OpFoldResult> mnkTiles) {
+  if (!(isa<linalg::MatmulOp>(matmulOp) ||
+        isa<linalg::BatchMatmulOp>(matmulOp))) {
+    return rewriter.notifyMatchFailure(matmulOp, "not a matmul-like operation");
+  }
+
+  if (mnkTiles.size() != 3)
+    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
+
+  if (matmulOp.hasDynamicShape())
+    return rewriter.notifyMatchFailure(matmulOp, "require static shape");
+
+  if (matmulOp.hasPureBufferSemantics())
+    return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
+
+  SmallVector<size_t, 3> dims{0, 1, 2};
+  // Skip the batch dimension if present.
+  bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(matmulOp);
+  if (isBatchMatmulOp)
+    dims = {1, 2, 3};
+
+  if (!validateFullTilesOnDims(cast<TilingInterface>(matmulOp.getOperation()),
+                               mnkTiles, dims)) {
+    return rewriter.notifyMatchFailure(matmulOp,
+                                       "expect packing full tiles only");
+  }
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  // The op is replaced, we need to set the insertion point after it.
+  rewriter.setInsertionPointAfter(matmulOp);
+
+  auto packedCanonicalMatmul = packMatmulGreedily(
+      rewriter, matmulOp, mnkTiles, /*mnkPaddedSizesNextMultipleOf=*/{},
+      /*mnkOrder=*/{0, 1, 2});
+  if (failed(packedCanonicalMatmul))
+    return failure();
+
+  assert(packedCanonicalMatmul->packOps.size() == 3 && "failed matmul packing");
+  assert(packedCanonicalMatmul->unPackOps.size() == 1 &&
+         "failed matmul unpacking");
+
+  SmallVector<int64_t> innerPerm = {1, 0};
+  SmallVector<int64_t> outerPerm = {1, 0};
+  // Leave the batch dimension as is.
+  if (isBatchMatmulOp)
+    outerPerm = {0, 2, 1};
+
+  auto packedMatmul =
+      packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
+                    packedCanonicalMatmul->packedLinalgOp,
+                    /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
+  if (failed(packedMatmul))
+    return failure();
+
+  return packedMatmul->transposedLinalgOp;
+}
+
+namespace {
+template <typename OpTy>
+struct PackMatmul : public OpRewritePattern<OpTy> {
+  PackMatmul(MLIRContext *context, ArrayRef<int64_t> blockFactors,
+             PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), blockFactors(blockFactors) {}
+
+  LogicalResult matchAndRewrite(OpTy matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (blockFactors.empty())
+      return failure();
+    auto packedMatmul =
+        packMatmulOp(rewriter, matmulOp,
+                     getAsOpFoldResult(rewriter.getI64ArrayAttr(blockFactors)));
+    if (failed(packedMatmul))
+      return failure();
+    return success();
+  }
+
+private:
+  SmallVector<int64_t> blockFactors;
+};
+
+// Entry point for packing matmul operations.
+// Pack MatmulOp as following:
+// [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+// Pack a BatchMatmulOp as following:
+// [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
+struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
+  using LinalgPackMatmulBase::LinalgPackMatmulBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(&getContext());
+    linalg::populatePackMatmulPatterns(patterns, blockFactors);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
+                                        ArrayRef<int64_t> blockFactors) {
+  patterns.add<PackMatmul<linalg::MatmulOp>, PackMatmul<linalg::BatchMatmulOp>>(
+      patterns.getContext(), blockFactors);
+}
diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/pack-matmul.mlir
new file mode 100644
index 00000000000000..d7023cfc30559b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pack-matmul.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
+
+func.func @block_matmul(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[BUF1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[PACK2]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_constant(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst = arith.constant dense<0.0> : tensor<128x128xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%cst : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_constant(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF_RES:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[BUF_OUT:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[BUF_OUT]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_producer(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %1 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_producer(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[BUF_RES:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[FILL]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_consumer(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = tensor.empty() : tensor<128x128xf32>
+  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %2 = linalg.add ins(%1, %arg3 : tensor<128x128xf32>, tensor<128x128xf32>)
+                  outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %2 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_consumer(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG3:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  outs({{.*}} : tensor<4x8x32x16xf32>)
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[OUT:.+]] = linalg.add
+// CHECK-SAME:  ins(%[[UNPACK]], %[[ARG3]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[BUF]] : tensor<128x128xf32>)
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_batch_matmul(
+    %arg0: tensor<512x64x128xf32>, %arg1: tensor<512x128x64xf32>, %arg2: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+  %0 = tensor.empty() : tensor<512x64x64xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+                           outs(%arg2 : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  return %1 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<512x64x128xf32>, %[[ARG1:.+]]: tensor<512x128x64xf32>, %[[ARG2:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[BUF0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[BUF1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[BUF1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[BUF2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[PACK2]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[OUT]] : tensor<512x64x64xf32>

@adam-smnk (Contributor, Author)

FYI @rengolin

@adam-smnk (Contributor, Author) commented Apr 23, 2024

This is a packing scheme aimed at CPUs to improve the performance of GEMM computations by improving data accesses.
The pass is primarily a driver of the existing upstream packing and tiling utilities with CPU-centric decisions.

The transformation has been tested within our downstream project (TPP-MLIR).
This is the first PR that aims to bring our Linalg tensor-level transformations upstream. Further follow-ups will include pack/unpack propagation, to amortize the cost of packing, and VNNI layout packing, which allows targeting type-aware SIMD instructions.

@nujaa (Contributor) commented Apr 23, 2024

Hi, what is the difference with structured.pack except for the support of batch_matmul? Could they be merged together?

@MaheshRavishankar (Contributor) left a comment

Thanks for the change. We have been doing something similar downstream as well.

@yzhang93 can you help review this and guide it based on the use cases we have worked through?

if (isBatchMatmulOp)
dims = {1, 2, 3};

if (!validateFullTilesOnDims(cast<TilingInterface>(matmulOp.getOperation()),
Contributor

This seems pessimistic. Pack operations have padding semantics, so we don't need to have full tiles only.
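
(For reference, this is roughly what a padded pack looks like when a dimension does not divide evenly into the block sizes; a minimal hand-written sketch with made-up shapes, not code from this PR:)

func.func @pack_with_padding(%src: tensor<127x127xf32>) -> tensor<4x8x32x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dest = tensor.empty() : tensor<4x8x32x16xf32>
  // Trailing partial tiles along both dimensions are filled with %cst so that
  // every 32x16 block is complete.
  %packed = tensor.pack %src padding_value(%cst : f32)
      inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %dest : tensor<127x127xf32> -> tensor<4x8x32x16xf32>
  return %packed : tensor<4x8x32x16xf32>
}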

@dcaballe (Contributor) commented Apr 24, 2024

I think this is part of the optionality that would be interesting to have (padding vs non-padding)...

@adam-smnk (Contributor, Author)

Enabled padding by default, but it can be optionally disabled.

// [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
// Pack a BatchMatmulOp as following:
// [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
Contributor

I don't see much value in a pass in MLIR. I think what is more valuable is the packMatmulOp method here.

Member

The pass is a way to get at the functionality and to test it without using the transform dialect. We see value in it for composing with our pipelines, and I see no cost in keeping it upstream. Just having transforms for now is very restrictive, and we don't want to start a discussion about transforms versus passes in this (and subsequent) PRs; that is an orthogonal issue.

This PR and the following ones will add passes to drive the framework because that's what we have. After we merge the core functionality, we can revisit and add them to upstream transforms.

Contributor

A huge +1 to having functionality and tests without the transform dialect. My operation mode has been:

  1. Add an API entry point for the transformation function.
  2. Optionally wrap the transformations within a pattern and have a "populate*Pattern" method.
  3. Have a test pass to test the functionality. I am not super opposed to having a pass, though.

Member

That's a good approach. The pass will help us upstream all of the functionality in smaller steps, rebase and recompose our pipeline to make sure we didn't lose anything upstream. Then iterate until complete.

We should also discuss separately how to compose the functionality with other compiler pipelines (like IREE), so that we can reuse the same logic everywhere.

@adam-smnk (Contributor, Author)

I opted to leave the pass primarily to allow testing and as a usage example.
It could be a test pass but might as well just be a normal pass.


void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
ArrayRef<int64_t> blockFactors) {
patterns.add<PackMatmul<linalg::MatmulOp>, PackMatmul<linalg::BatchMatmulOp>>(
Contributor

Actually, looking at this, my initial thought is that we don't need the patterns themselves, but I am OK either way.

@adam-smnk (Contributor, Author)

Left it for completeness with the pass but potentially both could be removed. No strong opinions here.

}

static FailureOr<linalg::LinalgOp>
packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
Contributor

Could we generalize this a bit? The packMatmulGreedily and packTranspose methods expose a lot of options. Can we add a struct that carries all these options and call these methods with it (instead of just the mnkTiles sizes)? We can make this an API entry point. This is IMO more valuable than the "populate*Patterns" method and the pass added here.

Contributor

+1

@adam-smnk (Contributor, Author)

Exposed more options and added more layout controls.

@adam-smnk (Contributor, Author)

Hi, what is the difference with structured.pack except for the support of batch_matmul?

Both share the same core logic under the hood. I imagine you could express this pass as a transform sequence.
The pass combines the two transforms pack_greedily and pack_transpose to achieve a more specific transformation that is beneficial for CPUs.

Could they be merged together?

I think the two serve somewhat different purposes.

@banach-space (Contributor) commented Apr 23, 2024

For folks less familiar with TPP, could you compare this to rewriting linalg.matmul as tensor.pack + linalg.mmt4d + tensor.unpack?

And could you include pseudo-IR before and after in the summary?
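
(For context, the mmt4d-based rewrite mentioned above looks roughly like the hand-written sketch below; it is not part of this PR. Note that linalg.mmt4d expects the RHS inner block as [nb][kb], whereas this pass produces a linalg.generic whose RHS inner block stays as [kb][nb].)

func.func @matmul_as_mmt4d(%A: tensor<128x128xf32>, %B: tensor<128x128xf32>,
                           %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  // LHS packed as [MB][KB][mb][kb].
  %a_buf = tensor.empty() : tensor<4x2x32x64xf32>
  %A_packed = tensor.pack %A inner_dims_pos = [0, 1] inner_tiles = [32, 64]
      into %a_buf : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
  // RHS packed as [NB][KB][nb][kb] (inner block transposed, as mmt4d expects).
  %b_buf = tensor.empty() : tensor<8x2x16x64xf32>
  %B_packed = tensor.pack %B outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
      into %b_buf : tensor<128x128xf32> -> tensor<8x2x16x64xf32>
  // Accumulator packed as [MB][NB][mb][nb].
  %c_buf = tensor.empty() : tensor<4x8x32x16xf32>
  %C_packed = tensor.pack %C inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %c_buf : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
  %res = linalg.mmt4d ins(%A_packed, %B_packed : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>)
                      outs(%C_packed : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
  %out = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %C : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
  return %out : tensor<128x128xf32>
}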

@yzhang93 (Contributor) commented Apr 24, 2024

Thanks for the change. We have been doing something similar downstream as well.

@yzhang93 can you help review this and guide it based on the use case we have worked through.

I think this pass may not be general enough for a broader audience. It limits the usage to linalg.matmul/linalg.batch_matmul ops (without any linalg.generic variants) and fixes the innerPerm/outerPerm to {1, 0}.

In our downstream development, I found it very flexible to directly call linalg::pack and packTranspose to meet our particular requirements for dimension permutations and the various generic versions of matmul. On the other hand, this pass looks like a wrapper for the packMatmulGreedily and packTranspose methods but has a rather narrow scope.

@MaheshRavishankar (Contributor)

Thanks for the change. We have been doing something similar downstream as well.
@yzhang93 can you help review this and guide it based on the use case we have worked through.

I think this pass may not be general enough for a broader audience. It limits the usage to linalg.matmul/linalg.batch_matmul ops (without any linalg.generic variants) and fixes the innerPerm/outerPerm to {1, 0}.

In our downstream development, I found it very flexible to directly call linalg::pack and packTranspose to meet our particular requirements for dimension permutations and the various generic versions of matmul. On the other hand, this pass looks like a wrapper for the packMatmulGreedily and packTranspose methods but has a rather narrow scope.

What does packMatmulGreedily do? It is fine to start with something that has a narrow scope to begin with but can be generalized as needed.

Contributor

Given that there's no other documentation attached to this PR, I'm using these tests to infer the high-level logic. That's a bit tricky ATM; it could be simplified with more meaningful variable names (MLIR and LIT) ;-)

For example:

func.func @block_matmul(%lhs: tensor<128x128xf32>, %rhs: tensor<128x128xf32>, %out: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = linalg.matmul  ins(%lhs, %rhs : tensor<128x128xf32>, tensor<128x128xf32>)
                      outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

or

func.func @block_matmul(%A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = linalg.matmul  ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                      outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

And then in LIT:

// CHECK: %[[LHS_PACKED:.+]] = tensor.pack %[[LHS]]

Or something similar. I don't really mind what the actual names are, as long as they help to map to the high level semantics of Matmul. Updated names will really help to review this, thanks!

@adam-smnk (Contributor, Author)

Sure thing, I'll also improve description as suggested.
An example with mmt4d should help a lot.

@adam-smnk (Contributor, Author)

Improved the description and test naming. Hopefully, now they are more understandable.

@adam-smnk (Contributor, Author)

I think this pass may not be general enough for a broader audience. It limits the usage to linalg.matmul/linalg.batch_matmul ops (without any linalg.generic variants) and fixes the innerPerm/outerPerm to {1, 0}.

I'll start with generalizing and exposing the API as Mahesh suggested. But I think there's still value in providing tools for the well-defined named ops. Patterns for common linalg.generic variants could be added later too. But if you need to go with truly generic representations, you probably need custom logic anyway.

On the other hand, this pass looks like a wrapper for packMatmulGreedily and packTranspose methods but has a rather narrow scope.

The aim is not necessarily to create a fully generic infrastructure; after all, there are existing tools for that. Although if we find enough common patterns, even better.
I'd say the main motivation here is to add a transformation that provides a benefit in many common cases. Perhaps it is still too narrow and/or I haven't considered enough use cases to judge that properly.

@dcaballe (Contributor) left a comment

Hey, thanks for the contribution! This is exciting! A couple of comments:

  • Why are you framing this as a CPU-only transformation? Are there any implicit limitations that we should be aware of?
  • How configurable would the layout be? Could we do plain transposes, decide which operands to transpose, and how many blocking levels to use? (Sorry if I missed something while skimming through the code)

}

static FailureOr<linalg::LinalgOp>
packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
Contributor

+1

@adam-smnk (Contributor, Author)

Why are you framing this as a CPU-only transformation? Are there any implicit limitations that we should be aware of?

AFAIK, it's not necessarily beneficial to pack for GPUs, to give a counter-example. We've also used packing only for CPUs, so it's the main motivation and use case I can confidently bring up.
It might be equally useful for other targets, which might use the same or a completely different packing layout/scheme.

How configurable would the layout be? Could we do plain transposes, decide which operands to transpose, and how many blocking levels to use? (Sorry if I missed something while skimming through the code)

This one worked best based on our use cases. I'll try to generalize the pass and see what can be done.

adam-smnk marked this pull request as a draft on April 25, 2024, 14:56
}

if (mnkTiles.size() != 3)
return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
@yifeizh2 (Contributor) commented May 6, 2024

The current implementation only takes one set of mnkTiles and tries to pack all matmul ops with the given tile sizes.
Will we consider use cases where we wish to pack different matmul ops with different tile sizes?

@adam-smnk (Contributor, Author)

Added a control callback that can specify packing sizes and layout configuration separately for each op. Hopefully, this addresses your use case.

adam-smnk changed the title from "[mlir][linalg] Pack matmul pass" to "[mlir][linalg] Block pack matmul pass" on May 7, 2024
@adam-smnk (Contributor, Author)

Thanks for all the feedback, everyone.
It allowed me to iterate on the design and create a more flexible transformation.

Changes:

  • renamed to 'block pack matmul' to clarify the main intent and avoid confusion with the existing pack transform (see the example invocation after this list)
  • exposed the transform's API outside of the pass
  • added a control function that allows specifying packing sizes and the desired layout on a per-operation basis
  • exposed more matmul packing options
  • enabled padding, which can be optionally disabled
  • relaxed checks to allow dynamic shapes
  • improved block packing controls to focus on the desired block layout; this puts the user's intent first and lets the transformation take care of inner and outer block transposition regardless of data layout, which adds support for (batch) matmul with transposed A/B operands
  • defaulted the transformation to the mmt4d layout; this should make it easier to understand the transformation and the available options
  • added PoC pass support for simple generics; more variants can probably be handled, but those can be tested and added later
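
As a quick reference, the renamed pass can be invoked like the RUN line below (only the block-factors option is shown; the spellings of the other newly added options are not listed in this thread, so they are omitted here):

// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64" -canonicalize -split-input-file | FileCheck %s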

adam-smnk marked this pull request as ready for review on May 7, 2024, 14:14
@adam-smnk (Contributor, Author) commented May 8, 2024

@MaheshRavishankar @yzhang93 Does the more generalized API fit your use cases better now?

There are surely still limitations when it comes to more complex generics or arbitrary layouts. Hopefully, the API is flexible enough to allow future extensions for the former. The latter might be out of scope here.

@MaheshRavishankar (Contributor) left a comment

This looks reasonable to me, except for having a pass. I'll approve it, but I would really appreciate it if we dropped the pass from core and moved it to tests.

@@ -141,4 +141,63 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
];
}

def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
Contributor

I am not going to hold this up, but I really dislike having such passes. I think this is just dead weight that can never fit all downstream uses; really, we should just need transformation methods/patterns, let the passes live downstream, and have only a test pass for testing. Such passes in core do not age well.

@rengolin (Member) May 9, 2024

Such passes in core do not age well.

That's true, but I think it's orthogonal. We need the functionality upstream so that we can all use and share it, and it will only age if no one uses it. Whether it lives in a test pass, a transform, or a dialect pass won't change that equation much.

I'd love it if IREE and other compilers could bring more functionality to this pass and start using the upstream version (parametrized, cost modelled, etc.). This is what keeps it from aging. If we don't provide a way for people to use it in their projects, then MLIR has fewer "batteries" (as @joker-eph usually says).

rengolin merged commit 4c3db25 into llvm:main on May 9, 2024
@banach-space (Contributor)

Thanks again for working on this!

This landed within an hour of Mahesh approving ... Could you allow a bit of extra time for other reviewers to take a look and to rubber stamp it? ;-) Thank you 🙏🏻

@rengolin (Member) commented May 9, 2024

Well, the PR is 2 weeks old and the last update was 2 days ago, so there was plenty of time for reviews.

LLVM doesn't really have the concept of rubber-stamping; one approval is enough, so what we do is iterate in-tree. Post-commit reviews are pretty common.

It'd be great if you can try it out and bring issues, PRs and ideas from your side. We want this functionality to work for everyone!

@banach-space (Contributor) commented May 9, 2024

Well, the PR is 2 weeks old and the last update was 2 days ago

This PR was marked as draft for most of that time, so I assumed that there was no rush and also that the PR wasn't ready for review.

From LLVM docs https://llvm.org/docs/CodeReview.html#lgtm-how-a-patch-is-accepted:

If approval is received very quickly, a patch author may also elect to wait before committing (and this is certainly considered polite for non-trivial patches). Especially given the global nature of our community, this waiting time should be at least 24 hours.

There are many ways to read it. For a bigger change like this one, I really appreciate it when people leave ~24 hrs between the first approval and landing in-tree. Not a formal requirement nor a policy, just a kind request.

Please don't take this the wrong way; I am OK with this change and support the effort. All my comments have been addressed.

EDIT

It'd be great if you can try it out and bring issues, PRs and ideas from your side.

From the initial submission it wasn't clear to me that this was targeting linalg.mmt4d. We've added some e2e tests recently for mmt4d:

We are also working on enabling this for scalable vectors. Here's our meta-ticket in IREE (the link will take you to a mini RFC in that ticket):

Admittedly, our contributions might be less relevant to folks focusing on fixed-width vectors.

@rengolin (Member) commented May 9, 2024

This is not related (yet) to mmt4d nor to scalable vectors, so I guess it should be fine. We want to get the packing propagation and constant folding next, which needs this one merged, so we'll add you on the following PRs to make sure we get your review.

As discussed above, if we have more users (you, IREE, us), then the code will stay relevant, so it'd be really good to get your feedback; hopefully we can adapt to your usage as well as ours.
