[mlir][linalg] Move transpose_matmul to targeted transform op #89717

Merged

Conversation

c-rhodes
Collaborator

This is more targeted than a blanket "apply everywhere" pattern. Follow-up to #89075 to address @ftynse's feedback.
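
For reference, a minimal transform script using the new op might look like this (adapted from the tests updated in this PR; the sequence and handle names are illustrative):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
      // Transposes the LHS by default; write `<rhs>` after the operand to
      // transpose the RHS instead.
      %transposed = transform.structured.transpose_matmul %matmul
        : (!transform.any_op) -> !transform.any_op
      transform.yield
    }
  }
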
@llvmbot
Member

llvmbot commented Apr 23, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Cullen Rhodes (c-rhodes)

Changes

This is more targeted than a blanket "apply everywhere" pattern. Follow-up to #89075 to address @ftynse's feedback.


Full diff: https://github.com/llvm/llvm-project/pull/89717.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+46-17)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+8)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+28-6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+98-77)
  • (modified) mlir/test/Dialect/Linalg/transpose-matmul-a.mlir (+2-3)
  • (modified) mlir/test/Dialect/Linalg/transpose-matmul-b.mlir (+2-3)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index beb4cb076f4947..d0ad4ccdf031d9 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,23 +73,6 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.linalg.transpose_matmul",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-    Collects patterns to convert Linalg matmul ops to transposed variants.
-
-    By default the LHS matrix is transposed. Set `inputToTranspose=<rhs>` to
-    instead transpose RHS matrix.
-  }];
-
-  let arguments = (ins
-    DefaultValuedAttr<TransposeMatmulInput,
-                      "TransposeMatmulInput::lhs">:$inputToTranspose);
-
-  let assemblyFormat = "(`<` $inputToTranspose^ `>`)? attr-dict";
-}
-
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
@@ -2429,6 +2412,52 @@ def TransposeConv2DOp : Op<Transform_Dialect,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeMatmulOp
+//===----------------------------------------------------------------------===//
+
+def TransposeMatmulOp : Op<Transform_Dialect,
+    "structured.transpose_matmul",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Convert Linalg matmul ops to transposed variants.
+
+    By default the LHS matrix is transposed. Specify `<rhs>` to instead
+    transpose RHS matrix.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported, i.e., not a
+    `linalg.matmul` or `linalg.batch_matmul`. Otherwise, the operation succeeds
+    and returns a handle to the transposed matmul op.
+  }];
+
+  let arguments = (ins
+    TransformHandleTypeInterface:$target,
+    DefaultValuedAttr<TransposeMatmulInput,
+                      "TransposeMatmulInput::lhs">:$inputToTranspose);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = [{
+    $target (`<` $inputToTranspose^ `>`)?
+    attr-dict `:` functional-type($target, results)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3bee911ca282ea..5ecf84fa9c7012 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1244,6 +1244,14 @@ FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                        linalg::Conv2DNhwcFhwcQOp op);
 
+/// Convert Linalg matmul ops to transposed variants.
+FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
+                                       linalg::MatmulOp op,
+                                       bool transposeLHS = true);
+FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
+                                            linalg::BatchMatmulOp op,
+                                            bool transposeLHS = true);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8f1faa83cbb9cc..b4463c1912d518 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -199,12 +199,6 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
 }
 
-void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
-    RewritePatternSet &patterns) {
-  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
-  linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
-}
-
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
@@ -3422,6 +3416,34 @@ DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeMatmulOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::MatmulOp op) {
+            return transposeMatmul(rewriter, op, transposeLHS);
+          })
+          .Case([&](linalg::BatchMatmulOp op) {
+            return transposeBatchMatmul(rewriter, op, transposeLHS);
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+  // Handle to the new matmul op with the transposed input.
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index a4a05b243ad2b4..aa0052ce47fa7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -18,7 +18,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace {
 /// Pattern to replace
 ///
 ///   linalg.matmul(a, b)
@@ -29,44 +28,107 @@ namespace {
 ///
 /// By default the LHS is transposed. Set `transposeLHS=false` to
 /// transpose RHS instead.
+FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
+                                                     linalg::MatmulOp matmulOp,
+                                                     bool transposeLHS) {
+  if (!bufferization::hasTensorSemantics(matmulOp))
+    return rewriter.notifyMatchFailure(
+        matmulOp, "only matmul ops with tensors are supported");
+
+  Location loc = matmulOp.getLoc();
+  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
+  auto type = cast<ShapedType>(input.getType());
+
+  SmallVector<Value> dynamicDims;
+  if (type.isDynamicDim(1))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+  if (type.isDynamicDim(0))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+  ArrayRef<int64_t> shape = type.getShape();
+  Value empty = rewriter.create<tensor::EmptyOp>(
+      loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
+      dynamicDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, input, empty, ArrayRef<int64_t>{1, 0});
+  Operation *newMatmulOp;
+  if (transposeLHS) {
+    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
+        loc, matmulOp.getResultTypes(),
+        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+        matmulOp.getOutputs());
+  } else {
+    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
+        loc, matmulOp.getResultTypes(),
+        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+        matmulOp.getOutputs());
+  }
+  rewriter.replaceOp(matmulOp, newMatmulOp);
+  return newMatmulOp;
+}
+
+/// Pattern to replace
+///
+///   linalg.batch_matmul(a, b)
+///
+/// with
+///
+///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+///
+/// Only the non-batch dimensions are transposed. By default the LHS is
+/// transposed. Set `transposeLHS=false` to transpose RHS instead.
+FailureOr<Operation *>
+mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
+                                   linalg::BatchMatmulOp batchMatmulOp,
+                                   bool transposeLHS) {
+  if (!bufferization::hasTensorSemantics(batchMatmulOp))
+    return rewriter.notifyMatchFailure(
+        batchMatmulOp, "only matmul ops with tensors are supported");
+
+  Location loc = batchMatmulOp.getLoc();
+  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
+  auto type = cast<ShapedType>(input.getType());
+
+  SmallVector<Value> dynamicDims;
+  if (type.isDynamicDim(0))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+  if (type.isDynamicDim(2))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
+  if (type.isDynamicDim(1))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+
+  ArrayRef<int64_t> shape = type.getShape();
+  Value empty = rewriter.create<tensor::EmptyOp>(
+      loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
+      type.getElementType(), dynamicDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
+  Operation *newMatmulOp;
+  if (transposeLHS) {
+    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
+        loc, batchMatmulOp.getResultTypes(),
+        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
+        batchMatmulOp.getOutputs());
+  } else {
+    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
+        loc, batchMatmulOp.getResultTypes(),
+        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
+        batchMatmulOp.getOutputs());
+  }
+  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
+  return newMatmulOp;
+}
+
+namespace {
 struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
   TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
       : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
 
-  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+  LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!bufferization::hasTensorSemantics(matmulOp))
-      return rewriter.notifyMatchFailure(
-          matmulOp, "only matmul ops with tensors are supported");
-
-    Location loc = matmulOp.getLoc();
-    Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
-    auto type = cast<ShapedType>(input.getType());
-
-    SmallVector<Value> dynamicDims;
-    if (type.isDynamicDim(1))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
-    if (type.isDynamicDim(0))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
-
-    ArrayRef<int64_t> shape = type.getShape();
-    Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
-        dynamicDims);
-    auto transposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, input, empty, ArrayRef<int64_t>{1, 0});
-    if (transposeLHS) {
-      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
-          matmulOp, matmulOp.getResultTypes(),
-          ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
-          matmulOp.getOutputs());
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
-          matmulOp, matmulOp.getResultTypes(),
-          ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
-          matmulOp.getOutputs());
+    if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
+      return failure();
     }
-
     return success();
   }
 
@@ -74,57 +136,16 @@ struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
   bool transposeLHS;
 };
 
-/// Pattern to replace
-///
-///   linalg.batch_matmul(a, b)
-///
-/// with
-///
-///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
-///
-/// Only the non-batch dimensions are transposed. By default the LHS is
-/// transposed. Set `transposeLHS=false` to transpose RHS instead.
 struct TransposeBatchMatmul final
     : public OpRewritePattern<linalg::BatchMatmulOp> {
   TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
       : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
 
-  LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
+  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!bufferization::hasTensorSemantics(batchMatmulOp))
-      return rewriter.notifyMatchFailure(
-          batchMatmulOp, "only matmul ops with tensors are supported");
-
-    Location loc = batchMatmulOp.getLoc();
-    Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
-    auto type = cast<ShapedType>(input.getType());
-
-    SmallVector<Value> dynamicDims;
-    if (type.isDynamicDim(0))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
-    if (type.isDynamicDim(2))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
-    if (type.isDynamicDim(1))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
-
-    ArrayRef<int64_t> shape = type.getShape();
-    Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
-        type.getElementType(), dynamicDims);
-    auto transposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
-    if (transposeLHS) {
-      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
-          batchMatmulOp, batchMatmulOp.getResultTypes(),
-          ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
-          batchMatmulOp.getOutputs());
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
-          batchMatmulOp, batchMatmulOp.getResultTypes(),
-          ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
-          batchMatmulOp.getOutputs());
+    if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
+      return failure();
     }
-
     return success();
   }
 
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir b/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
index 1d2460f5467a5d..b1f33cfa56327e 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
@@ -2,10 +2,9 @@
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.transpose_matmul %matmul : (!transform.any_op) -> (!transform.any_op)
     %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %0 {
-      transform.apply_patterns.linalg.transpose_matmul
-    } : !transform.any_op
     transform.apply_cse to %0 : !transform.any_op
     transform.apply_patterns to %0 {
       transform.apply_patterns.canonicalization
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir b/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
index eecd76f1ecca7d..41e64c04dc6e59 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
@@ -2,10 +2,9 @@
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.transpose_matmul %matmul <rhs> : (!transform.any_op) -> (!transform.any_op)
     %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %0 {
-      transform.apply_patterns.linalg.transpose_matmul <rhs>
-    } : !transform.any_op
     transform.apply_cse to %0 : !transform.any_op
     transform.apply_patterns to %0 {
       transform.apply_patterns.canonicalization
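
To make the effect of the underlying rewrite concrete (this PR only moves the rewrite behind a targeted transform op; the rewrite logic is unchanged), here is a hand-written sketch with illustrative static shapes and made-up value names:

  // Before:
  %c = linalg.matmul ins(%a, %b : tensor<16x8xf32>, tensor<8x16xf32>)
                     outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>

  // After (default <lhs>): the LHS is materialized transposed and the matmul
  // becomes the transpose-A variant.
  %empty = tensor.empty() : tensor<8x16xf32>
  %a_t = linalg.transpose ins(%a : tensor<16x8xf32>)
                          outs(%empty : tensor<8x16xf32>) permutation = [1, 0]
  %c = linalg.matmul_transpose_a ins(%a_t, %b : tensor<8x16xf32>, tensor<8x16xf32>)
                                 outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>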

@llvmbot
Member

llvmbot commented Apr 23, 2024

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes and full diff identical to the @llvm/pr-subscribers-mlir-linalg comment above.
@ftynse
Member

ftynse left a comment

Thanks! I assume most of the PR just moves the code around so I didn't review it in detail. LMK if attention is needed somewhere.

@c-rhodes
Collaborator Author

> Thanks! I assume most of the PR just moves the code around so I didn't review it in detail. LMK if attention is needed somewhere.

Thanks for the speedy review! It's mostly mechanical; there are no changes to the actual rewrites.

@c-rhodes c-rhodes merged commit be1c72d into llvm:main Apr 23, 2024
@c-rhodes c-rhodes deleted the linalg-matmul-transpose-move-transform-pattern branch April 23, 2024 09:52