-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][linalg][elementwise] Fold transpose into new elementwise #130207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Javed Absar (javedabsar1). Changes. Full diff: https://github.com/llvm/llvm-project/pull/130207.diff (3 Files Affected):
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..f7b1d2c9dfcb3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -601,12 +601,24 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
[{
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, ElementwiseOp::getRegionBuilder());
- }]>
+ }]>,
+
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+ "ElementwiseKindAttr":$kind,
+ "ArrayAttr":$indexingMaps,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("kind", kind);
+ $_state.addAttribute("indexing_maps", indexingMaps);
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, ElementwiseOp::getRegionBuilder());
+ }]>
];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{
/// Get the arity enum corresponding to the kind of op, e.g. if arg is
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..f6b7c32659bb5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -4285,6 +4286,47 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+namespace {
+struct FoldTranspose : public OpRewritePattern<ElementwiseOp> { // Rewrite pattern: folds linalg.transpose producers of an ElementwiseOp's inputs into the op's indexing maps.
+ using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ElementwiseOp op,
+ PatternRewriter &rewriter) const override {
+ bool changed = false; // set once at least one input has been folded
+ SmallVector<Value> newIns;
+ SmallVector<AffineMap> newMaps;
+ for (OpOperand *operand : op.getDpsInputOperands()) {
+ AffineMap map = op.getMatchingIndexingMap(operand);
+ auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
+
+ if (!map.isIdentity() || !transposeOp) {
+ // Not foldable (non-identity map, or not produced by a TransposeOp): keep the original operand and its map.
+ newIns.push_back(operand->get());
+ newMaps.push_back(map);
+ continue;
+ }
+ newIns.push_back(transposeOp.getInput()); // read the transpose's source directly
+ // Reuse the indexing map of the transpose's input operand (its inverse permutation map) for the new input.
+ newMaps.push_back(transposeOp.getMatchingIndexingMap(
+ transposeOp.getDpsInputOperand(0)));
+ changed = true;
+ }
+ if (!changed) // no input was rewritten, so report that the pattern did not apply
+ return failure();
+ newMaps.push_back(op.getIndexingMapsArray().back()); // carry over the op's last indexing map unchanged (presumably the output map)
+
+ rewriter.replaceOpWithNewOp<ElementwiseOp>(
+ op, newIns, op.getDpsInits()[0], op.getKindAttr(),
+ rewriter.getAffineMapArrayAttr(newMaps));
+ return success();
+ }
+};
+} // namespace
+void ElementwiseOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldTranspose>(context);
+}
+
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
new file mode 100644
index 0000000000000..7b2ff0b6de12e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//
+// CHECK: func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %empty = tensor.empty() : tensor<8x16x32xf32>
+ %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+ ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK: func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+ %empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
+ %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
|
@llvm/pr-subscribers-mlir-linalg Author: Javed Absar (javedabsar1). Changes. Full diff: https://github.com/llvm/llvm-project/pull/130207.diff (3 Files Affected):
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..f7b1d2c9dfcb3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -601,12 +601,24 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
[{
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, ElementwiseOp::getRegionBuilder());
- }]>
+ }]>,
+
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+ "ElementwiseKindAttr":$kind,
+ "ArrayAttr":$indexingMaps,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("kind", kind);
+ $_state.addAttribute("indexing_maps", indexingMaps);
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, ElementwiseOp::getRegionBuilder());
+ }]>
];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{
/// Get the arity enum corresponding to the kind of op, e.g. if arg is
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..f6b7c32659bb5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -4285,6 +4286,47 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+namespace {
+struct FoldTranspose : public OpRewritePattern<ElementwiseOp> { // Rewrite pattern: folds linalg.transpose producers of an ElementwiseOp's inputs into the op's indexing maps.
+ using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ElementwiseOp op,
+ PatternRewriter &rewriter) const override {
+ bool changed = false; // set once at least one input has been folded
+ SmallVector<Value> newIns;
+ SmallVector<AffineMap> newMaps;
+ for (OpOperand *operand : op.getDpsInputOperands()) {
+ AffineMap map = op.getMatchingIndexingMap(operand);
+ auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
+
+ if (!map.isIdentity() || !transposeOp) {
+ // Not foldable (non-identity map, or not produced by a TransposeOp): keep the original operand and its map.
+ newIns.push_back(operand->get());
+ newMaps.push_back(map);
+ continue;
+ }
+ newIns.push_back(transposeOp.getInput()); // read the transpose's source directly
+ // Reuse the indexing map of the transpose's input operand (its inverse permutation map) for the new input.
+ newMaps.push_back(transposeOp.getMatchingIndexingMap(
+ transposeOp.getDpsInputOperand(0)));
+ changed = true;
+ }
+ if (!changed) // no input was rewritten, so report that the pattern did not apply
+ return failure();
+ newMaps.push_back(op.getIndexingMapsArray().back()); // carry over the op's last indexing map unchanged (presumably the output map)
+
+ rewriter.replaceOpWithNewOp<ElementwiseOp>(
+ op, newIns, op.getDpsInits()[0], op.getKindAttr(),
+ rewriter.getAffineMapArrayAttr(newMaps));
+ return success();
+ }
+};
+} // namespace
+void ElementwiseOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldTranspose>(context);
+}
+
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
new file mode 100644
index 0000000000000..7b2ff0b6de12e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//
+// CHECK: func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %empty = tensor.empty() : tensor<8x16x32xf32>
+ %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+ ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK: func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+ %empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
+ %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, this looks good to me. Wait a day for others to check again.
[mlir][linalg][elementwise] Fold transpose into new elementwise (#130207): Fold transpose into the new elementwise op, which has affine maps attached. Broadcast folding will be added in the next diff.
Fold transpose into the new elementwise op, which has affine maps attached.
Broadcast folding will be added in the next diff.