Skip to content

[mlir][linalg][elementwise] Fold transpose into new elementwise #130207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 12, 2025

Conversation

javedabsar1
Copy link
Contributor

@javedabsar1 javedabsar1 commented Mar 7, 2025

Fold transpose into new elementwise Op which has affine-map attached.
Will add broadcast folding in next diff.

@llvmbot
Copy link
Member

llvmbot commented Mar 7, 2025

@llvm/pr-subscribers-mlir

Author: Javed Absar (javedabsar1)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/130207.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+13-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+42)
  • (added) mlir/test/Dialect/Linalg/elementwise/fold.mlir (+43)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..f7b1d2c9dfcb3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -601,12 +601,24 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       [{
         buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
           attributes, ElementwiseOp::getRegionBuilder());
-      }]>
+      }]>,
+
+     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+          "ElementwiseKindAttr":$kind,
+          "ArrayAttr":$indexingMaps,
+          CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("kind", kind);
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, ElementwiseOp::getRegionBuilder());
+       }]>
     ];
 
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
       /// Get the arity enum corresponding to the kind of op, e.g. if arg is
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..f6b7c32659bb5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -4285,6 +4286,47 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+namespace {
+struct FoldTranspose : public OpRewritePattern<ElementwiseOp> {
+  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElementwiseOp op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    SmallVector<Value> newIns;
+    SmallVector<AffineMap> newMaps;
+    for (OpOperand *operand : op.getDpsInputOperands()) {
+      AffineMap map = op.getMatchingIndexingMap(operand);
+      auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
+
+      if (!map.isIdentity() || !transposeOp) {
+        // push in original operand and its map.
+        newIns.push_back(operand->get());
+        newMaps.push_back(map);
+        continue;
+      }
+      newIns.push_back(transposeOp.getInput());
+      // push in transposeOp's inverse permutation map.
+      newMaps.push_back(transposeOp.getMatchingIndexingMap(
+          transposeOp.getDpsInputOperand(0)));
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    newMaps.push_back(op.getIndexingMapsArray().back());
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(
+        op, newIns, op.getDpsInits()[0], op.getKindAttr(),
+        rewriter.getAffineMapArrayAttr(newMaps));
+    return success();
+  }
+};
+} // namespace
+void ElementwiseOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldTranspose>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
new file mode 100644
index 0000000000000..7b2ff0b6de12e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//
+// CHECK:  func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:       indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME:       ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT:    return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %empty = tensor.empty() : tensor<8x16x32xf32>
+  %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty :  tensor<8x16x32xf32>) permutation = [1, 0, 2]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+                          ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK:  func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:              indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) ->  tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+  %empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
+  %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty :  tensor<?x?xf32>) permutation = [1, 0]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          ins(%A, %transposed_B : tensor<?x?xf32>,  tensor<?x?xf32>)
+                          outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}

@llvmbot
Copy link
Member

llvmbot commented Mar 7, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Javed Absar (javedabsar1)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/130207.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+13-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+42)
  • (added) mlir/test/Dialect/Linalg/elementwise/fold.mlir (+43)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..f7b1d2c9dfcb3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -601,12 +601,24 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       [{
         buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
           attributes, ElementwiseOp::getRegionBuilder());
-      }]>
+      }]>,
+
+     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+          "ElementwiseKindAttr":$kind,
+          "ArrayAttr":$indexingMaps,
+          CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("kind", kind);
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, ElementwiseOp::getRegionBuilder());
+       }]>
     ];
 
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
       /// Get the arity enum corresponding to the kind of op, e.g. if arg is
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..f6b7c32659bb5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -4285,6 +4286,47 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+namespace {
+struct FoldTranspose : public OpRewritePattern<ElementwiseOp> {
+  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElementwiseOp op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    SmallVector<Value> newIns;
+    SmallVector<AffineMap> newMaps;
+    for (OpOperand *operand : op.getDpsInputOperands()) {
+      AffineMap map = op.getMatchingIndexingMap(operand);
+      auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
+
+      if (!map.isIdentity() || !transposeOp) {
+        // push in original operand and its map.
+        newIns.push_back(operand->get());
+        newMaps.push_back(map);
+        continue;
+      }
+      newIns.push_back(transposeOp.getInput());
+      // push in transposeOp's inverse permutation map.
+      newMaps.push_back(transposeOp.getMatchingIndexingMap(
+          transposeOp.getDpsInputOperand(0)));
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    newMaps.push_back(op.getIndexingMapsArray().back());
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(
+        op, newIns, op.getDpsInits()[0], op.getKindAttr(),
+        rewriter.getAffineMapArrayAttr(newMaps));
+    return success();
+  }
+};
+} // namespace
+void ElementwiseOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldTranspose>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
new file mode 100644
index 0000000000000..7b2ff0b6de12e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//
+// CHECK:  func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:       indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME:       ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT:    return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %empty = tensor.empty() : tensor<8x16x32xf32>
+  %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty :  tensor<8x16x32xf32>) permutation = [1, 0, 2]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+                          ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK:  func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:              indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) ->  tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+  %empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
+  %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty :  tensor<?x?xf32>) permutation = [1, 0]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          ins(%A, %transposed_B : tensor<?x?xf32>,  tensor<?x?xf32>)
+                          outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this looks good to me. Wait a day for others to check again.

@javedabsar1 javedabsar1 merged commit ecf4d99 into llvm:main Mar 12, 2025
11 checks passed
frederik-h pushed a commit to frederik-h/llvm-project that referenced this pull request Mar 18, 2025
…#130207)

Fold transpose into new elementwise Op which has affine-map attached.
Will add broadcast folding in next diff.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants