[mlir][linalg] Add patterns to convert matmul to transposed variants #89075


Merged

Conversation

@c-rhodes (Collaborator) commented Apr 17, 2024

This patch introduces a pass `-linalg-matmul-to-matmul-transpose-a`, which transposes the A matrix of a Linalg matmul operation with the aim of making memory accesses contiguous.

Our work enabling a lowering path from `linalg.matmul` to ArmSME has revealed that the current lowering results in non-contiguous memory accesses for the A matrix and very poor performance.

These patterns provide a simple option to fix this.
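
For illustration, here is a minimal before/after sketch of the rewrite on static shapes, mirroring the test added in this patch (shapes are arbitrary):

```mlir
// Before: plain matmul; %A is read in its original (16x8) layout.
%0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>)
                   outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>

// After: %A is transposed once up front, then consumed by matmul_transpose_a.
%empty = tensor.empty() : tensor<8x16xf32>
%At = linalg.transpose ins(%A : tensor<16x8xf32>)
                       outs(%empty : tensor<8x16xf32>) permutation = [1, 0]
%0 = linalg.matmul_transpose_a
       ins(%At, %B : tensor<8x16xf32>, tensor<8x16xf32>)
       outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
```
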
@llvmbot (Member) commented Apr 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Cullen Rhodes (c-rhodes)

Changes

This patch introduces a pass -linalg-matmul-to-matmul-transpose-a, which transposes the A matrix of a Linalg matmul operation with the aim of making memory accesses contiguous.

Our work enabling a lowering path from linalg.matmul to ArmSME has revealed that the current lowering results in non-contiguous memory accesses for the A matrix and very poor performance.

This pass provides a simple option to fix this.


Full diff: https://github.com/llvm/llvm-project/pull/89075.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+9)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp (+92)
  • (added) mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir (+76)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..38be6e49f574c9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,13 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
+def LinalgMatmulToMatmulTransposeAPass
+    : Pass<"linalg-matmul-to-matmul-transpose-a"> {
+  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a`.";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let description = [{
+    Transposes the A matrix of a `linalg.matmul` for contiguous access.
+  }];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index feb3b3f03cf538..0d354c666b1742 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,6 +1616,9 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
+/// Pattern to replace `linalg.matmul` with `linalg.matmul_transpose_a`.
+void populateMatmulToMatmulTransposeAPattern(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 513c54de5d7bfc..bca4954f959da3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MatmulToMatmulTransposeA.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
new file mode 100644
index 00000000000000..45551cd9167b60
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
@@ -0,0 +1,92 @@
+//===- MatmulToMatmulTransposeA.cpp - Linalg matmul to matmul_transpose_a -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This rewrite and pass transposes the A matrix of a `linalg.matmul` operation
+// with the aim of the memory accesses becoming contiguous.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMATMULTOMATMULTRANSPOSEAPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-matmul-to-matmul-transpose-a"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace `linalg.matmul(a, b)` with
+/// `linalg.matmul_transpose_a(linalg.transpose(a), b)`.
+struct MatmulToMatmulTransposeA final
+    : public OpRewritePattern<linalg::MatmulOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(matmulOp))
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only matmul ops with tensors are supported");
+
+    Value a = matmulOp.getInputs()[0];
+    auto aType = cast<ShapedType>(a.getType());
+    if (aType.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only 2-D matmul ops are supported");
+
+    Location loc = matmulOp.getLoc();
+
+    SmallVector<Value> dynamicDims;
+    if (aType.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 1));
+    if (aType.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 0));
+
+    auto aShape = aType.getShape();
+    SmallVector<int64_t> transposedShape{aShape[1], aShape[0]};
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, transposedShape, aType.getElementType(), dynamicDims);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto transposeAOp =
+        rewriter.create<linalg::TransposeOp>(loc, a, empty, perm);
+    rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+        matmulOp, matmulOp.getResultTypes(),
+        ValueRange{transposeAOp->getResult(0), matmulOp.getInputs()[1]},
+        matmulOp.getOutputs());
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateMatmulToMatmulTransposeAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<MatmulToMatmulTransposeA>(patterns.getContext());
+}
+
+namespace {
+struct LinalgMatmulToMatmulTransposeAPass
+    : public impl::LinalgMatmulToMatmulTransposeAPassBase<
+          LinalgMatmulToMatmulTransposeAPass> {
+  using impl::LinalgMatmulToMatmulTransposeAPassBase<
+      LinalgMatmulToMatmulTransposeAPass>::
+      LinalgMatmulToMatmulTransposeAPassBase;
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateMatmulToMatmulTransposeAPattern(patterns);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
new file mode 100644
index 00000000000000..c3c8dff98ba3c9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -linalg-matmul-to-matmul-transpose-a -cse -canonicalize -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @static(
+// CHECK-SAME:                      %[[A:.*]]: tensor<16x8xf32>,
+// CHECK-SAME:                      %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
+// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK:           return %[[C]] : tensor<16x16xf32>
+// CHECK:         }
+func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %init = tensor.empty() : tensor<16x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>) outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @dynamic(
+// CHECK-SAME:                       %[[A:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                       %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
+// CHECK:           %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           return %[[C]] : tensor<?x?xf32>
+// CHECK:         }
+func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %B, %c1 : tensor<?x?xf32>
+  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @mixed(
+// CHECK-SAME:                     %[[A:.*]]: tensor<?x8xf32>,
+// CHECK-SAME:                     %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
+// CHECK:           %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK:           return %[[B0]] : tensor<?x16xf32>
+// CHECK:         }
+func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %A, %c0 : tensor<?x8xf32>
+  %init = tensor.empty(%d0) : tensor<?x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x16xf32>) -> tensor<?x16xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<?x8xf32>, tensor<8x16xf32>) outs(%C : tensor<?x16xf32>) -> tensor<?x16xf32>
+  return %0 : tensor<?x16xf32>
+}


github-actions bot commented Apr 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@MaheshRavishankar (Contributor) left a comment

Thanks @c-rhodes. This seems like a very specific use case, though... The possibilities here are too many. For one, we shouldn't add a pass for it at all...

Broadly, I am wondering if we even need to add such one-off patterns in core. If there is a "general framework" to convert from one named op type to another prescriptively, that would be useful infra to have, but a one-off doesn't seem useful and is hard to manage.

@c-rhodes (Collaborator, Author)

Thanks @c-rhodes. This seems like a very specific use case, though... The possibilities here are too many. For one, we shouldn't add a pass for it at all...

Broadly, I am wondering if we even need to add such one-off patterns in core. If there is a "general framework" to convert from one named op type to another prescriptively, that would be useful infra to have, but a one-off doesn't seem useful and is hard to manage.

Thanks for the comments, @MaheshRavishankar. I can't say for sure this is relevant outside of our use case (SME), but the non-contiguous reads occur at quite a high level (Linalg / Vector) and other targets must face similar issues. To be honest, I wasn't sure about this approach or how well it would be received, but it seemed like the simplest option.

I know there's MMT4D, and @banach-space has spent a bit of time looking at this, but it's a bit more involved and will require some work. I also came across the CanonicalizeContractMatmulToMMT pattern in VectorTransforms whilst implementing this pass; it was originally in the GPU lowerings (+ CombineTransferReadOpTranspose) and appears to be solving a similar problem, but I've yet to look closely into it.

If some generic, targeted named-op conversion infra as you suggest would be useful, we could look into this.

@MacDue (Member) commented Apr 18, 2024

Thanks @c-rhodes. This seems like a very specific use case, though... The possibilities here are too many. For one, we shouldn't add a pass for it at all...

Broadly, I am wondering if we even need to add such one-off patterns in core. If there is a "general framework" to convert from one named op type to another prescriptively, that would be useful infra to have, but a one-off doesn't seem useful and is hard to manage.

Note that the main use of this will be in IREE (for ArmSME). Would you be okay with adding a one-off pattern/rewrite there (rather than in the core)?

@banach-space (Contributor)

Hey @c-rhodes , thanks for implementing this!

This seems like a very specific use case though....

ArmSME is just one motivating example. The applicability of this is much wider - it's a very generic transformation:

  • linalg.matmul -> linalg.transpose + linalg.matmul_transpose_a.

There are other examples like this in tree (linalg.conv -> linalg.transpose + linalg.conv):

Here's an example of wrapping a handful of patterns into a pass (in case that's the actual concern here):

Also, an example of a rather target-specific set of optimisations:

I'm saying "target-specific" as a chunk of that won't work for scalable vectors (so one could argue that it's not "generic enough").

I know there's MMT4D, and @banach-space has spent a bit of time looking at this, but it's a bit more involved and will require some work.

Even when we do have MMT4D, this transformation will remain a useful stepping stone that I'd like to use for benchmarking.

@MaheshRavishankar , to me this change is consistent with other transformations that we have in MLIR. I also don't see it being a maintenance burden - it's rather tiny.

If there is a "general framework" to convert from one named op type to another prescriptively, that would be useful infra to have, but a one-off doesn't seem useful and is hard to manage.

It's not obvious to me that that would be better than what's proposed here. Isn't that what patterns+passes are for? Do you have something specific in mind? Any references? Thanks!

@banach-space (Contributor) left a comment

LGTM, thanks Cullen!

I suggest that we generalise this a bit and allow 2 variants:

  1. linalg.matmul(A, B) --> linalg.transpose(A) + linalg.matmul_transpose_a(A, B)
  2. linalg.matmul(A, B) --> linalg.transpose(B) + linalg.matmul_transpose_b(A, B)

Otherwise this transformation would only be beneficial for platforms lowering linalg.matmul to outer-products (e.g. ArmSME). We should keep it more general.
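
By analogy with the A variant in the test file, the second suggested variant above would produce IR along these lines (a sketch; shapes arbitrary):

```mlir
// Variant 2 (sketch): transpose B once, then use matmul_transpose_b.
%empty = tensor.empty() : tensor<16x8xf32>
%Bt = linalg.transpose ins(%B : tensor<8x16xf32>)
                       outs(%empty : tensor<16x8xf32>) permutation = [1, 0]
%0 = linalg.matmul_transpose_b
       ins(%A, %Bt : tensor<16x8xf32>, tensor<16x8xf32>)
       outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
```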

@c-rhodes (Collaborator, Author)

LGTM, thanks Cullen!

I suggest that we generalise this a bit and allow 2 variants:

1. `linalg.matmul(A, B)` --> `linalg.transpose(A) + linalg.matmul_transpose_a(A, B)`

2. `linalg.matmul(A, B)` --> `linalg.transpose(B) + linalg.matmul_transpose_b(A, B)`

Otherwise this transformation would only be beneficial for platforms lowering linalg.matmul to outer-products (e.g. ArmSME). We should keep it more general.

Thanks Andrzej. I've added a flag to control which input gets transposed.

@c-rhodes changed the title from "[mlir][linalg] Add pass to transpose A matrix of matmul op" to "[mlir][linalg] Add pass to matmul op", Apr 18, 2024
@c-rhodes changed the title from "[mlir][linalg] Add pass to matmul op" to "[mlir][linalg] Add pass to transpose matmul op", Apr 18, 2024
@dcaballe (Contributor) left a comment

I don't see any major concerns... We have linalg.matmul and linalg.matmul_transpose_a/b operations, so having patterns to go from one to another seems like something basic to have (we may also want to add the batch variants).

Having a simple pass to exercise them should also be OK as long as we set the expectations properly (i.e., this is expected to be a simple matmul transposition, not target-specific tuning, complex data-layout transformations, etc.).
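
For the batch variants mentioned above, the analogous rewrite would look like this (a sketch, assuming the existing `linalg.batch_matmul_transpose_a` named op; permutation `[0, 2, 1]` leaves the batch dimension in place):

```mlir
// Batch variant (sketch): transpose A within each batch, then use
// batch_matmul_transpose_a. A: 4x16x8 -> 4x8x16, B: 4x8x16, C: 4x16x16.
%empty = tensor.empty() : tensor<4x8x16xf32>
%At = linalg.transpose ins(%A : tensor<4x16x8xf32>)
                       outs(%empty : tensor<4x8x16xf32>) permutation = [0, 2, 1]
%0 = linalg.batch_matmul_transpose_a
       ins(%At, %B : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
       outs(%C : tensor<4x16x16xf32>) -> tensor<4x16x16xf32>
```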

@c-rhodes (Collaborator, Author)

I don't see any major concerns... We have linalg.matmul and linalg.matmul_transpose_a/b operations, so having patterns to go from one to another seems like something basic to have (we may also want to add the batch variants).

Having a simple pass to exercise them should also be OK as long as we set the expectations properly (i.e., this is expected to be a simple matmul transposition, not target-specific tuning, complex data-layout transformations, etc.).

Thanks for reviewing. I'll add a comment to the pass to set expectations 👍

@banach-space (Contributor) left a comment

LGTM, thanks for addressing my comments!

@MaheshRavishankar (Contributor)

I'll explain more later, but you could do it as a PDL/PDLL/TD script to do this replacement as a preprocessing step. I know the vector dialect does some things, but it has been a long-running issue that charting a path through the vector dialect is challenging. I wouldn't use that as an example.
This is a fairly straightforward pattern that can live anywhere downstream, but as a one-off pattern it is strange.

@dcaballe (Contributor)

Not sure I follow the concern: is it about the pattern or the pass? I don't see any reason for the patterns to live downstream, and I'm not sure I follow the one-off argument. We have plenty of patterns, ranging from very simple ones to complex, collective transformations. The proposed patterns are useful in general to transform an existing op into other existing ops and will include different variants (standard matmul, batch, and even vecmat/matvec), so I don't see any kind of special-casing here, and I do see applicability even beyond SME (e.g., in-place matmul transposes for a non-DT approach).

If the concern is about the pass, I guess we could turn the pass into a "test" pass and let users add these patterns to their downstream passes, but, again, I think the pass is perfectly valid as it's doing a well-defined set of transformations. We can also consider adding a TD op, but I don't think we have been enforcing one alternative versus the other. Both options have been successfully co-existing and have been added on an as-needed basis, and I don't think there is a reason in this PR to follow a different approach.

@banach-space (Contributor)

This is a fairly straightforward pattern that can live anywhere downstream

All our development happens upstream - that's specifically to make sure that our design aligns with the upstream requirements and overall guidelines for contributing to MLIR. Also, we want to make this compiler technology available to everyone using MLIR core rather than requiring people to use forks. To this end, we are keen to adapt our design based on the comments that we receive from our reviewers - we really appreciate all the guidance and feedback 🙏🏻

However, if this is more about "what belongs upstream" vs "what belongs downstream", then it sounds like a more general discussion on MLIR/Linalg/Vector. I am not aware of any guidelines on what qualifies as an "upstream" contribution to Linalg. If we go by past precedent, then I feel that the examples that I shared above clearly indicate that this contribution qualifies. Otherwise, we should probably bring this up on Discourse and see how others feel? More specifically ...

as a one off pattern is strange

IMHO, this is very subjective. Two separate reviewers have confirmed that they'd like to use this (myself and Diego). Surely having one implementation upstream is better than having 2 or more implementations downstream?

As for ...

you could do it as a PDL/pdll/TD script to do this replacement as a preprocessing

... yes, every transformation can be driven in multiple ways (e.g. TD vs C++). But as Diego points out:

I don't think we have been enforcing one alternative versus the other

AFAIK, there's no one "right" way of driving things in MLIR/Linalg. As a reviewer, I always try to be accommodating - if a certain approach works better for a particular project then we should support that. If that helps people adapt MLIR/Linalg then that's beneficial for all of us! IMHO such flexibility is crucial for the success of MLIR/Linalg.

@MacDue (Member) left a comment

I don't personally have a strong opinion on how/where this transform is implemented, so looking just at the code this LGTM. I have some very minor nits (that apply to both rewrites), but feel free to ignore those.

@MaheshRavishankar (Contributor)

We are going in circles here. I understand that this use case is important, and accommodating downstream users is important to me as well. But IMO we need to balance what goes into core and what lives downstream.

First of all, we can't have a pass. Passes in MLIR IMO are mostly for testing. So agreed, we can have a test pass there (or, as is the norm now, a test transform op).

Broadly, though, my concern is what shapes of these kinds of patterns we want to accommodate, and how you manage how they are used.
This is adding matmul -> transpose + matmul_transpose_a + transpose.
What about matmul -> transpose + matmul_transpose_b + transpose?
What about the reverse of these (which is also reasonable): transpose + matmul_transpose_b + transpose -> matmul, and other variants? There is no bound to these kinds of patterns and no story on how and when to use them. The justification for all of this depends on downstream use cases, so that is where such patterns should live. I don't see this as "core infrastructure" for building an MLIR-based compiler, but rather just a particular use case.

@dcaballe (Contributor)

This is adding matmul -> transpose + matmul_transpose_a + transpose.
What about matmul -> transpose + matmul_transpose_b + transpose?
What about the reverse of these (which is also reasonable): transpose + matmul_transpose_b + transpose -> matmul, and other variants.

I think all those transformations are reasonably valid and should have a place upstream, in the same way we have patterns that transform a matmul into an outer product and fma, a multi-reduction, some LLVM matrix intrinsics, and whatnot. It's all a matter of having a certain level of optionality. Different users have different needs, and I don't think each and every one of them should reinvent these patterns downstream. We have found ourselves in situations where having this optionality allowed us to quickly explore different implementation options that we initially didn't think would be optimal for our use cases.

Also, if this is needed for something like SVE (and potentially any outer product engine, really), it seems relevant enough to me to have it upstream, right?

@MaheshRavishankar (Contributor) commented Apr 20, 2024

I think all those transformations are reasonably valid and should have a place upstream, in the same way we have patterns that transform a matmul into an outer product and fma, a multi-reduction, some LLVM matrix intrinsics, and whatnot. It's all a matter of having a certain level of optionality. Different users have different needs, and I don't think each and every one of them should reinvent these patterns downstream. We have found ourselves in situations where having this optionality allowed us to quickly explore different implementation options that we initially didn't think would be optimal for our use cases.

This is mixing a lot of concepts. If you're talking about the vector dialect, then maybe your experience and my experience are very different. The patterns there suffer from a lot of complementary options which only make sense when put together in a narrow way. The vector dialect transformation patterns don't seem to form a coherent transformation sequence that someone can navigate for their use case. My concern is that adding one-off patterns without a way to manage them effectively would just mean we end up with a whole bunch of patterns, some doing the inverse of others, and "populate*" methods that no one can really navigate.

Also, if this is needed for something like SVE (and potentially any outer product engine, really), it seems relevant enough to me to have it upstream, right?

It might be needed for SVE, but it is being done at the Linalg level, which doesn't directly map to SVE. MLIR has only a limited set of end-to-end examples. These patterns only make sense in that context, and without tying them to end-to-end examples they are just one-offs, hard to rationalize or find. That is why
(a) either these live downstream, or
(b) if there is an e2e pipeline being built as an example in MLIR, then these should live scoped to those pipelines, and not as a pattern somewhere in core where someone without the context of the e2e pipeline cannot make sense of why it is there.

- address Ben's nits.
- replace pass with transform op.
@c-rhodes requested a review from ftynse as a code owner April 22, 2024 08:41
@c-rhodes changed the title from "[mlir][linalg] Add pass to transpose matmul op" to "[mlir][linalg] Add patterns to convert matmul to transposed variants", Apr 22, 2024
@c-rhodes (Collaborator, Author)

We are going in circles here. I understand that this use case is important, and accommodating downstream users is important to me as well. But IMO we need to balance what goes into core and what lives downstream.

First of all, we can't have a pass. Passes in MLIR IMO are mostly for testing. So agreed, we can have a test pass there (or, as is the norm now, a test transform op).

FWIW I've removed the pass in favour of a transform op.
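
With that change, the patterns would be driven from a transform script along these lines (a sketch; the pattern-descriptor name `transform.apply_patterns.linalg.transpose_matmul` is an assumption based on the revised PR, not quoted from it):

```mlir
// Sketch: apply the rewrite patterns via the transform dialect instead of
// a dedicated pass. The descriptor op name inside the region is assumed.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %module: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops{["func.func"]} in %module
        : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.linalg.transpose_matmul
    } : !transform.any_op
    transform.yield
  }
}
```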

@nicolasvasilache (Contributor) left a comment

I have no concerns with this now that it is implemented as an apply_patterns op.
I would like to see this evolve into a more general set of rewrites, but I see no reason to disallow it: it's a go on my end.

@ftynse (Member) commented Apr 22, 2024

I have no concerns with this now that it is implemented as an apply_patterns op.

OTOH, I'd prefer this to be more targeted than a blanket "apply everywhere" pattern, but there is always time to add that later.

@c-rhodes (Collaborator, Author)

I have no concerns with this now that it is implemented as an apply_patterns op.

OTOH, I'd prefer this to be more targeted than a blanket "apply everywhere" pattern, but there is always time to add that later.

Do you mean to match on a specific matmul operation?

@ftynse (Member) commented Apr 22, 2024

Do you mean to match on a specific matmul operation?

I mean take a handle to the matmul operation to transpose. Where it comes from is orthogonal to the transformation. I'm happy to have it in this commit or separately.
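
Concretely, a handle-based variant might look roughly like the following (hypothetical: the dedicated `transform.structured.transpose_matmul` op and its signature are assumptions, sketched to show the shape of the API rather than quoted from the follow-up):

```mlir
// Hypothetical sketch: match one specific matmul and transpose just that op,
// instead of applying the pattern to everything under the payload root.
%matmul = transform.structured.match ops{["linalg.matmul"]} in %func
    : (!transform.any_op) -> !transform.any_op
transform.structured.transpose_matmul %matmul
    : (!transform.any_op) -> !transform.any_op
```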

@stellaraccident (Contributor) left a comment

Agreed with the revised approach, as it goes with the grain of the others and can be expanded in the future.

@c-rhodes (Collaborator, Author)

Do you mean to match on a specific matmul operation?

I mean take a handle to the matmul operation to transpose. Where it comes from is orthogonal to the transformation. I'm happy to have it in this commit or separately.

👍 I'll address that in a follow-up

@MaheshRavishankar dismissed their stale review April 22, 2024 16:44

Dropping my hold since there is community support for this

@c-rhodes merged commit 7922534 into llvm:main Apr 23, 2024
@c-rhodes deleted the mlir-linalg-matmul-to-matmul-transpose-a-pass branch April 23, 2024 06:21
@c-rhodes (Collaborator, Author)

Do you mean to match on a specific matmul operation?

I mean take a handle to the matmul operation to transpose. Where it comes from is orthogonal to the transformation. I'm happy to have it in this commit or separately.

👍 I'll address that in a follow-up

#89717

c-rhodes added a commit that referenced this pull request Apr 23, 2024
More targeted than a blanket "apply everywhere" pattern. Follow-up to
#89075 to address @ftynse's feedback.