[mlir][linalg] raise generic to named ops. #110421
Conversation
Add support for specializing linalg.broadcast and linalg.transpose from linalg.generic. Also refactors code to reuse the specialization checks.
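For illustration, a minimal sketch of the kind of generic this raising targets (the function name, shapes, and maps below are made up for this example, not taken from the patch): an all-parallel generic with one input, one init, an identity init map, and a body that only yields its input. With the new isaBroadcastOpInterface check, this should specialize to a linalg.broadcast with dimensions = [0]:

#bcast_in  = affine_map<(d0, d1, d2) -> (d1, d2)>
#bcast_out = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @raise_to_broadcast(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %0 = linalg.generic
         {indexing_maps = [#bcast_in, #bcast_out],
          iterator_types = ["parallel", "parallel", "parallel"]}
         ins(%A : tensor<?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    // Body just forwards the input element.
    linalg.yield %in : f32
  } -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
// Expected after --linalg-specialize-generic-ops (assuming the pass behaves as the new check describes):
//   %broadcasted = linalg.broadcast ins(%A : tensor<?x?xf32>)
//                    outs(%Out : tensor<?x?x?xf32>) dimensions = [0]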
@llvm/pr-subscribers-mlir-linalg Author: Javed Absar (javedabsar1). Changes: Add support for specializing linalg.broadcast and linalg.transpose from generic. Full diff: https://github.com/llvm/llvm-project/pull/110421.diff (6 files affected):
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0fcaa96ade4031..6f1c243cc4396d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -120,6 +120,16 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp,
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
bool isaCopyOpInterface(LinalgOp linalgOp);
+/// Checks whether `genericOp` is semantically equivalent to a
+/// `linalg.broadcast`. Returns broadcast dimensions if true.
+std::optional<SmallVector<int64_t>>
+isaBroadcastOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a
+/// `linalg.transpose`. Returns permuted dimensions if true.
+std::optional<SmallVector<int64_t>>
+isaTransposeOpInterface(GenericOp genericOp);
+
/// Checks whether a given `genericOp` is semantically equivalent to a single
/// linalgelementwise unary op. e.g. linalg.exp.
/// A linalg.generic body could be a series of unary elementwise ops e.g.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 0b5191664a9e2f..5842128091972a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -22,6 +22,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
+#include <numeric>
using namespace mlir;
using namespace mlir::linalg;
@@ -49,18 +50,41 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
+// Returns true if all loops of the linalgOp are parallel
+static bool isAllParallel(LinalgOp op) {
+ return op.getNumParallelLoops() == op.getNumLoops();
+}
+
+// Returns true if and only if linalgOp takes one input and one init.
+static bool isSingleInputOutput(LinalgOp op) {
+ return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
+}
+// Returns true if genericOp body is just a yieldOp that yields
+// input operand as result.
+static bool isSingleYieldOp(GenericOp op) {
+ if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
+ return false;
+
+ Block *body = op.getBody();
+ if (body->getOperations().size() != 1)
+ return false;
+
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+ if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+ yieldOp->getOperand(0) != body->getArgument(0))
+ return false;
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//
bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
- // Structural.
- if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+ // Structural and operands
+ if (!isAllParallel(linalgOp) || !isSingleInputOutput(linalgOp))
return false;
- // Operands and maps.
- if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
- return false;
auto mapRange = linalgOp.getIndexingMapsArray();
if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
!mapRange.back().isIdentity()) {
@@ -75,8 +99,8 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
//===----------------------------------------------------------------------===//
std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
// Structural.
- if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
- genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+ if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+ !isSingleYieldOp(genericOp))
return std::nullopt;
// Input should be referenced and init should not.
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
OpOperand *value = genericOp.getDpsInputOperand(0);
if (!genericOp.isScalar(value))
return std::nullopt;
+ return value->get();
+}
- Block *body = genericOp.getBody();
- if (body->getOperations().size() != 1)
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+ // Structural.
+ if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+ !isSingleYieldOp(genericOp))
return std::nullopt;
- auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
- if (!yieldOp || yieldOp.getNumOperands() != 1 ||
- yieldOp->getOperand(0) != body->getArgument(0))
+ auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
+ auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
+ if (!isa<MemRefType, RankedTensorType>(t0) ||
+ !isa<MemRefType, RankedTensorType>(t1))
return std::nullopt;
- return value->get();
+
+ // Check output is identity map. Injective function could also be
+ // a permutation of indices and expressible in linalg.generic but
+ // is not expressible for named broadcast op.
+ auto dstMap = genericOp.getIndexingMapsArray()[1];
+ if (!dstMap.isIdentity())
+ return std::nullopt;
+
+ SmallVector<int64_t> position;
+ auto srcMap = genericOp.getIndexingMapsArray()[0];
+
+ // Check input map is monotonically increasing DimIds.
+ for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
+ auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
+ if (!expr)
+ return std::nullopt;
+ int64_t pos = expr.getPosition();
+ if (i > 0 && pos <= position[i - 1])
+ return std::nullopt;
+ position.push_back(expr.getPosition());
+ }
+
+ SmallVector<int64_t> broadcastedDims;
+ auto numDims = srcMap.getNumDims();
+ for (auto dim : llvm::seq<int64_t>(0, numDims)) {
+ if (!llvm::is_contained(position, dim))
+ broadcastedDims.push_back(dim);
+ }
+ return broadcastedDims;
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaTransposeOpInterface(GenericOp genericOp) {
+ // Structural.
+ if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+ !isSingleYieldOp(genericOp))
+ return std::nullopt;
+
+ // mapping checks.
+ auto mapRange = genericOp.getIndexingMapsArray();
+ if (mapRange.size() != 2 || !mapRange.back().isIdentity() ||
+ !mapRange.front().isPermutation())
+ return std::nullopt;
+
+ SmallVector<int64_t> permutation;
+ auto map = mapRange.front();
+ for (unsigned i = 0; i < map.getNumResults(); ++i) {
+ auto expr = llvm::cast<AffineDimExpr>(map.getResults()[i]);
+ permutation.push_back(expr.getPosition());
+ }
+ return permutation;
}
//===----------------------------------------------------------------------===//
@@ -106,8 +192,7 @@ static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
// Check all loops are parallel.
- if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
- genericOp.getNumLoops() < 1)
+ if (!isAllParallel(genericOp) || genericOp.getNumLoops() < 1)
return false;
// Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4d7b748d7200e2..dfafffce9d9b60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -259,18 +259,43 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
+ // Copy
if (isaCopyOpInterface(genericOp)) {
LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
return namedOp;
}
+ // Fill
if (isaFillOpInterface(genericOp)) {
LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
return namedOp;
}
+ // Broadcast
+ std::optional<SmallVector<int64_t>> equivalentToBroadcast =
+ isaBroadcastOpInterface(genericOp);
+ if (equivalentToBroadcast) {
+ auto dims = *equivalentToBroadcast;
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
+ genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+ dims);
+ return namedOp;
+ }
+
+ // Transpose
+ std::optional<SmallVector<int64_t>> equivalentToTranspose =
+ isaTransposeOpInterface(genericOp);
+ if (equivalentToTranspose) {
+ auto permutation = *equivalentToTranspose;
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
+ genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+ permutation);
+ return namedOp;
+ }
+
+ // Elementwise Unary
if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
Operation *op = &genericOp.getBody()->front();
if (isa<math::ExpOp>(op)) {
@@ -279,6 +304,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
}
}
+ // Elementwise Binary
if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
bool swap = areBinOpsSwapped(genericOp);
Operation *op = &genericOp.getBody()->front();
@@ -300,6 +326,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
}
}
+ // Contraction - e.g. matmul
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
new file mode 100644
index 00000000000000..d6915ec8fbbf6f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: broadcast_first_dimension
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) dimensions = [0]
+//
+func.func @broadcast_first_dimension(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %res = linalg.broadcast ins(%A: tensor<?x?xf32>) outs(%Out: tensor<?x?x?xf32>) dimensions = [0]
+ return %res : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: broadcast_mid_dimension
+// CHECK-SAME: %[[A:.+]]: tensor<3x5xf32>, %[[Out:.+]]: tensor<3x4x5xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<3x5xf32>) outs(%[[Out]] : tensor<3x4x5xf32>) dimensions = [1]
+//
+func.func @broadcast_mid_dimension(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+ %res = linalg.broadcast ins(%A: tensor<3x5xf32>) outs(%Out: tensor<3x4x5xf32>) dimensions = [1]
+ return %res : tensor<3x4x5xf32>
+}
+
+
+// CHECK-LABEL: broadcast_multiple_dimensions
+// CHECK-SAME: %[[A:.+]]: tensor<4x5x7xf32>, %[[Out:.+]]: tensor<3x4x5x6x7x8x9xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<4x5x7xf32>) outs(%[[Out]] : tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
+//
+func.func @broadcast_multiple_dimensions(%A: tensor<4x5x7xf32>, %Out: tensor<3x4x5x6x7x8x9xf32>) -> tensor<3x4x5x6x7x8x9xf32> {
+ %res = linalg.broadcast ins(%A: tensor<4x5x7xf32>) outs(%Out: tensor<3x4x5x6x7x8x9xf32>) dimensions = [0,3,5,6]
+ return %res : tensor<3x4x5x6x7x8x9xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
new file mode 100644
index 00000000000000..ebc42c903e6e3e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: linalg_transpose
+// CHECK-SAME: %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
+//
+func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
+ %res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
+ return %res : tensor<64x16xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 35679db7412f30..31f2f6b1ab513f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -4,18 +4,6 @@
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1, d0)>
-func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
- // expected-note @below {{when applied to this op}}
- linalg.generic {
- indexing_maps = [#map1, #map],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- }
- return
-}
-
func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
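Note that the last hunk above removes the @broadcast_copy_expect_no_match test because such a generic now matches the broadcast check and is raised. For a case the new isaBroadcastOpInterface is expected to keep rejecting, consider a generic whose init map is a permutation rather than the identity. The following is an illustrative sketch (function name, shapes, and maps are made up here, not taken from the patch):

#in  = affine_map<(d0, d1) -> (d0)>
#out = affine_map<(d0, d1) -> (d1, d0)>
func.func @not_raised_to_broadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // Expected to stay a linalg.generic: the init indexing map (d1, d0) is a
  // permutation, and isaBroadcastOpInterface requires it to be the identity.
  %0 = linalg.generic
         {indexing_maps = [#in, #out],
          iterator_types = ["parallel", "parallel"]}
         ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}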
@shahidact this PR overlaps the "is broadcast" functionality in your PR #104783. Please make sure we use similar ideas across the linalg dialect. Thanks!
Hi @rengolin and @shahidact. My generic-to-named raising is based on broadcast, not a coupling of matmul and transpose/broadcast. @rengolin, please review my patch on the merits of current linalg.
Ping!
Hey, sorry, I did not mean to review this PR with my tag, just to alert @shahidact to your PR so he can check his own PR against it. |
Some comments; otherwise looks good to me. Thanks!
Sorry about the delay with this Javed! Looks good modulo some small suggestions. Also, could you add some negative tests? Thanks!
//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//
bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural.
-  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+  // Structural and operands
[nit] I don't understand this comment :) I appreciate that you are effectively inheriting this, but let's clarify: does "Structural and operands" mean "check the structure (all loops parallel?) and the operands (single input/output?)"?
Also trying to make sure I understand 😅
// Returns true if all loops of the linalgOp are parallel
static bool isAllParallel(LinalgOp op) {
  return op.getNumParallelLoops() == op.getNumLoops();
}

// Returns true if and only if linalgOp takes one input and one init.
static bool isSingleInputOutput(LinalgOp op) {
  return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
}
// Returns true if genericOp body is just a yieldOp that yields
// input operand as result.
static bool isSingleYieldOp(GenericOp op) {
  if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
    return false;

  Block *body = op.getBody();
  if (body->getOperations().size() != 1)
    return false;

  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
      yieldOp->getOperand(0) != body->getArgument(0))
    return false;
  return true;
}
Do these belong here? IIUC, the comment above ("Interface utility functions") refers to ODS/TableGen "interfaces" (i.e. none of these is an InterfaceMethod). Having said that, why not add them to the interface?
  if (!isa<MemRefType, RankedTensorType>(t0) ||
      !isa<MemRefType, RankedTensorType>(t1))
What else could it be? Perhaps checking for ShapedType would be enough?
      !isSingleYieldOp(genericOp))
    return std::nullopt;

  // mapping checks.
[nit] Suggested change:
-  // mapping checks.
+  // Check the maps.
// CHECK-NOT: linalg.generic
// CHECK: %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
//
func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
[nit] Skip linalg in the func name (repeating info already available).
Perhaps add 3D and 1D cases? And identity?
Added a 3D test, thanks for the suggestion. A 1D test on transpose would just get DCE-d out.
Thanks, Renato, for reviewing this. Yes, I was a bit confused by your earlier statement; apologies.
Added. Thanks.
Hi @banach-space @rengolin
Thanks for addressing my comments. LGTM modulo the comment re the test file (did you mean to expand roundtrip.mlir?).
Renato has effectively already approved, but I'd wait till tomorrow to make sure he has no further comments.
Why a new file instead of re-using roundtrip.mlir? Note that this file is called "roundtrip-broadcast.mlir", but it tests both broadcasts and transposes.
I double-checked: transpose and broadcast are in separate files, e.g. the linalg.transpose tests are in roundtrip-transpose.mlir. It may be that in the browser it appears mashed up across comments.
LGTM, thanks!
Thanks a lot @rengolin and @banach-space for helping improve the patch.
Add support for specializing linalg.broadcast and linalg.transpose from generic. Also does some refactoring to reuse specialization checks, migrating some common checks into shared helpers.
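The transpose side of the raising works analogously: an all-parallel, single-input/single-init generic whose init map is the identity and whose input map is a permutation should specialize to linalg.transpose, with the permutation read off the positions of the input map's dim expressions. Below is an illustrative sketch (names and shapes are made up, mirroring the 2D case in roundtrip-transpose.mlir; the expected result assumes the pass behaves as the new isaTransposeOpInterface check describes):

#perm_in  = affine_map<(d0, d1) -> (d1, d0)>
#perm_out = affine_map<(d0, d1) -> (d0, d1)>
func.func @raise_to_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
  %0 = linalg.generic
         {indexing_maps = [#perm_in, #perm_out],
          iterator_types = ["parallel", "parallel"]}
         ins(%A : tensor<16x64xf32>) outs(%Out : tensor<64x16xf32>) {
  ^bb0(%in: f32, %out: f32):
    // Body just forwards the input element; the permutation lives in the maps.
    linalg.yield %in : f32
  } -> tensor<64x16xf32>
  return %0 : tensor<64x16xf32>
}
// Expected after --linalg-specialize-generic-ops:
//   %transposed = linalg.transpose ins(%A : tensor<16x64xf32>)
//                   outs(%Out : tensor<64x16xf32>) permutation = [1, 0]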