[mlir][linalg] raise generic to named ops. #110421

Merged: 3 commits into llvm:main, Oct 11, 2024

Conversation

@javedabsar1 (Contributor) commented Sep 29, 2024

Add support for specializing linalg.broadcast and linalg.transpose from generic. Also does some refactoring to reuse specialization checks, migrating some common checks to op interface methods.

Add support for specializing linalg.broadcast and linalg.transpose
from generic. Also refactors to reuse specialization checks.
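
For illustration, a minimal sketch of the kind of generic this change raises, modelled on the new roundtrip-transpose.mlir test below (the function name is made up): an all-parallel generic with a single input and init, a bare linalg.yield body, a permutation input map, and an identity output map specializes to linalg.transpose.

```mlir
// Permutation-only generic: all loops parallel, one input and one init,
// the body just yields the input element.
#in  = affine_map<(d0, d1) -> (d1, d0)>
#out = affine_map<(d0, d1) -> (d0, d1)>
func.func @transpose_as_generic(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
  %0 = linalg.generic {indexing_maps = [#in, #out],
                       iterator_types = ["parallel", "parallel"]}
       ins(%A : tensor<16x64xf32>) outs(%Out : tensor<64x16xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<64x16xf32>
  // --linalg-specialize-generic-ops should rewrite this to:
  //   linalg.transpose ins(%A ...) outs(%Out ...) permutation = [1, 0]
  return %0 : tensor<64x16xf32>
}
```
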
@llvmbot (Member) commented Sep 29, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Javed Absar (javedabsar1)

Changes

Add support for specializing linalg.broadcast and linalg.transpose from generic.
Also does some refactoring to reuse specialization checks.


Full diff: https://github.com/llvm/llvm-project/pull/110421.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (+10)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+100-15)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+27)
  • (added) mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir (+32)
  • (added) mlir/test/Dialect/Linalg/roundtrip-transpose.mlir (+11)
  • (modified) mlir/test/Dialect/Linalg/transform-op-specialize.mlir (-12)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0fcaa96ade4031..6f1c243cc4396d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -120,6 +120,16 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp,
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
+/// Checks whether `genericOp` is semantically equivalent to a
+///  `linalg.broadcast`. Returns broadcast dimensions if true.
+std::optional<SmallVector<int64_t>>
+isaBroadcastOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a
+///  `linalg.transpose`. Returns permuted dimensions if true.
+std::optional<SmallVector<int64_t>>
+isaTransposeOpInterface(GenericOp genericOp);
+
 /// Checks whether a given `genericOp` is semantically equivalent to a single
 /// linalgelementwise unary op. e.g. linalg.exp.
 /// A linalg.generic body could be a series of unary elementwise ops e.g.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 0b5191664a9e2f..5842128091972a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include <algorithm>
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -49,18 +50,41 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
   return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
 }
 
+// Returns true if all loops of the linalgOp are parallel
+static bool isAllParallel(LinalgOp op) {
+  return op.getNumParallelLoops() == op.getNumLoops();
+}
+
+// Returns true if and only if linalgOp takes one input and one init.
+static bool isSingleInputOutput(LinalgOp op) {
+  return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
+}
+// Returns true if genericOp body is just a yieldOp that yields
+// input operand as result.
+static bool isSingleYieldOp(GenericOp op) {
+  if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
+    return false;
+
+  Block *body = op.getBody();
+  if (body->getOperations().size() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//
 
 bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural.
-  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+  // Structural and operands
+  if (!isAllParallel(linalgOp) || !isSingleInputOutput(linalgOp))
     return false;
 
-  // Operands and maps.
-  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
-    return false;
   auto mapRange = linalgOp.getIndexingMapsArray();
   if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
       !mapRange.back().isIdentity()) {
@@ -75,8 +99,8 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
 //===----------------------------------------------------------------------===//
 std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   // Structural.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
     return std::nullopt;
 
   // Input should be referenced and init should not.
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   OpOperand *value = genericOp.getDpsInputOperand(0);
   if (!genericOp.isScalar(value))
     return std::nullopt;
+  return value->get();
+}
 
-  Block *body = genericOp.getBody();
-  if (body->getOperations().size() != 1)
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
     return std::nullopt;
 
-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
+  auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
+  auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(t0) ||
+      !isa<MemRefType, RankedTensorType>(t1))
     return std::nullopt;
-  return value->get();
+
+  // Check output is identity map. Injective function could also be
+  // a permutation of indices and expressible in linalg.generic but
+  // is not expressible for named broadcast op.
+  auto dstMap = genericOp.getIndexingMapsArray()[1];
+  if (!dstMap.isIdentity())
+    return std::nullopt;
+
+  SmallVector<int64_t> position;
+  auto srcMap = genericOp.getIndexingMapsArray()[0];
+
+  // Check input map is monotonically increasing DimIds.
+  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
+    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
+    if (!expr)
+      return std::nullopt;
+    int64_t pos = expr.getPosition();
+    if (i > 0 && pos <= position[i - 1])
+      return std::nullopt;
+    position.push_back(expr.getPosition());
+  }
+
+  SmallVector<int64_t> broadcastedDims;
+  auto numDims = srcMap.getNumDims();
+  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
+    if (!llvm::is_contained(position, dim))
+      broadcastedDims.push_back(dim);
+  }
+  return broadcastedDims;
+}
+
+//===----------------------------------------------------------------------===//
+// TranposeOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaTransposeOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
+    return std::nullopt;
+
+  // mapping checks.
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.back().isIdentity() ||
+      !mapRange.front().isPermutation())
+    return std::nullopt;
+
+  SmallVector<int64_t> permutation;
+  auto map = mapRange.front();
+  for (unsigned i = 0; i < map.getNumResults(); ++i) {
+    auto expr = llvm::cast<AffineDimExpr>(map.getResults()[i]);
+    permutation.push_back(expr.getPosition());
+  }
+  return permutation;
 }
 
 //===----------------------------------------------------------------------===//
@@ -106,8 +192,7 @@ static bool
 isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
                                           unsigned arity) {
   // Check all loops are parallel.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumLoops() < 1)
+  if (!isAllParallel(genericOp) || genericOp.getNumLoops() < 1)
     return false;
 
   // Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4d7b748d7200e2..dfafffce9d9b60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -259,18 +259,43 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
 //===----------------------------------------------------------------------===//
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
+  // Copy
   if (isaCopyOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
+  // Fill
   if (isaFillOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
+  // Broadcast
+  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
+      isaBroadcastOpInterface(genericOp);
+  if (equivalentToBroadcast) {
+    auto dims = *equivalentToBroadcast;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        dims);
+    return namedOp;
+  }
+
+  // Transpose
+  std::optional<SmallVector<int64_t>> equivalentToTranspose =
+      isaTransposeOpInterface(genericOp);
+  if (equivalentToTranspose) {
+    auto permutation = *equivalentToTranspose;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        permutation);
+    return namedOp;
+  }
+
+  // Elementwise Unary
   if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
     Operation *op = &genericOp.getBody()->front();
     if (isa<math::ExpOp>(op)) {
@@ -279,6 +304,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
+  // Elementwise Binary
   if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
     bool swap = areBinOpsSwapped(genericOp);
     Operation *op = &genericOp.getBody()->front();
@@ -300,6 +326,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
+  // Contraction - e.g. matmul
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
diff --git a/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
new file mode 100644
index 00000000000000..d6915ec8fbbf6f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: broadcast_first_dimension
+// CHECK-SAME:   %[[A:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         %broadcasted = linalg.broadcast ins(%[[A]] : tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) dimensions = [0]
+//
+func.func @broadcast_first_dimension(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+   %res = linalg.broadcast ins(%A: tensor<?x?xf32>) outs(%Out: tensor<?x?x?xf32>) dimensions = [0]
+  return %res : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: broadcast_mid_dimension
+// CHECK-SAME:   %[[A:.+]]: tensor<3x5xf32>, %[[Out:.+]]: tensor<3x4x5xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         %broadcasted = linalg.broadcast ins(%[[A]] : tensor<3x5xf32>) outs(%[[Out]] : tensor<3x4x5xf32>) dimensions = [1]
+//
+func.func @broadcast_mid_dimension(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+   %res = linalg.broadcast ins(%A: tensor<3x5xf32>) outs(%Out: tensor<3x4x5xf32>) dimensions = [1]
+  return %res : tensor<3x4x5xf32>
+}
+
+
+// CHECK-LABEL: broadcast_multiple_dimensions
+// CHECK-SAME:   %[[A:.+]]: tensor<4x5x7xf32>, %[[Out:.+]]: tensor<3x4x5x6x7x8x9xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         %broadcasted = linalg.broadcast ins(%[[A]] : tensor<4x5x7xf32>) outs(%[[Out]] : tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
+//
+func.func @broadcast_multiple_dimensions(%A: tensor<4x5x7xf32>, %Out: tensor<3x4x5x6x7x8x9xf32>) -> tensor<3x4x5x6x7x8x9xf32> {
+   %res = linalg.broadcast ins(%A: tensor<4x5x7xf32>) outs(%Out: tensor<3x4x5x6x7x8x9xf32>) dimensions = [0,3,5,6]
+  return %res : tensor<3x4x5x6x7x8x9xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
new file mode 100644
index 00000000000000..ebc42c903e6e3e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: linalg_transpose
+// CHECK-SAME:  %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
+// CHECK-NOT:   linalg.generic
+// CHECK:  %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
+//
+func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
+  %res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
+  return %res : tensor<64x16xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 35679db7412f30..31f2f6b1ab513f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -4,18 +4,6 @@
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1, d0)>
 
-func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
-  // expected-note @below {{when applied to this op}}
-  linalg.generic {
-    indexing_maps = [#map1, #map], 
-    iterator_types = ["parallel", "parallel"]}
-    ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      linalg.yield %in : f32
-  }
-  return
-}
-
 func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
   // expected-note @below {{when applied to this op}}
   linalg.generic {

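To illustrate the broadcast matching introduced above (a sketch, not taken from the patch; the function name and shapes are made up): the input map must list a strictly increasing subset of the loop dimensions, and the dimensions missing from it become the `dimensions` attribute of the named op.

```mlir
// Input map (d0, d1, d2, d3) -> (d1, d3): strictly increasing dim positions
// [1, 3]; the missing dims [0, 2] are the broadcast dimensions.
#in  = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
#out = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @broadcast_as_generic(%A: tensor<5x7xf32>, %Out: tensor<4x5x6x7xf32>) -> tensor<4x5x6x7xf32> {
  %0 = linalg.generic {indexing_maps = [#in, #out],
                       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
       ins(%A : tensor<5x7xf32>) outs(%Out : tensor<4x5x6x7xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<4x5x6x7xf32>
  // Expected specialization:
  //   linalg.broadcast ins(%A ...) outs(%Out ...) dimensions = [0, 2]
  return %0 : tensor<4x5x6x7xf32>
}
// A map such as (d0, d1, d2) -> (d2, d0) is rejected: the positions are not
// increasing, so the op is a transpose plus broadcast, not a plain linalg.broadcast.
```
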
@llvmbot (Member) commented Sep 29, 2024
@llvm/pr-subscribers-mlir

@rengolin (Member) commented:
@shahidact this PR intersects the "is broadcast" functionality with your PR #104783. Please make sure we use similar ideas across the linalg dialect. Thanks!

@javedabsar1 (Contributor, Author) commented:
"is broadcast" functionality with y

Hi @rengolin and @shahidact. My generic-to-named specialization here is based on linalg.broadcast; it is not a coupling of matmul with transpose/broadcast.

@rengolin please review my patch on the merits of current linalg. Generic-to-named still has many named-op translations pending, and it is a good feature to have these pass through. If linalg.matmul grows a permutation attribute as you propose, I will extend the current generic-to-named path, which translates contractions to matmul variants such as linalg.matmul_transpose_a, to cover those forms as well.
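
For context on that last point, here is a sketch (not part of this patch; the function name and shapes are illustrative) of a contraction generic with a transposed-A indexing map, the kind of IR the existing contraction specialization can raise to linalg.matmul_transpose_a. Whether the specializer picks that exact named op depends on the existing contraction matching; this only illustrates the shape of IR being discussed.

```mlir
// d0 = M, d1 = N, d2 = K (reduction).
#mapA = affine_map<(d0, d1, d2) -> (d2, d0)>  // A indexed as K x M, i.e. transposed.
#mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @contraction_transposed_a(%A: tensor<8x4xf32>, %B: tensor<8x16xf32>,
                                    %C: tensor<4x16xf32>) -> tensor<4x16xf32> {
  %0 = linalg.generic {indexing_maps = [#mapA, #mapB, #mapC],
                       iterator_types = ["parallel", "parallel", "reduction"]}
       ins(%A, %B : tensor<8x4xf32>, tensor<8x16xf32>) outs(%C : tensor<4x16xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %mul = arith.mulf %a, %b : f32
    %add = arith.addf %c, %mul : f32
    linalg.yield %add : f32
  } -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}
```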

@javedabsar1 (Contributor, Author) commented:
Ping!

@javedabsar1 (Contributor, Author) commented:
> Ping!

Ping! @MaheshRavishankar @rengolin @banach-space

@rengolin (Member) commented Oct 7, 2024

"is broadcast" functionality with y

Hi @rengolin and shahidact. My generic to naming is based on broadcast . Not a coupling of matmul and transpose/broadcast.

@rengolin please review my patch on merits of current linalg. generic to named has many named ops translation still pending. It is a good feature to have these pass through. If linalg.matmul expands to perm attribute as you propose, the current generic-to-named which translates contractions to matmul variations e.g. linalg.matmul_transpose_a ... I will extend to cover them as well.

Hey, sorry, I did not mean to review this PR with my tag, just to alert @shahidact to your PR so he can check his own PR against it.

@rengolin (Member) left a comment

Some comments, otherwise, looks good to me. Thanks!

@banach-space (Contributor) left a comment

Sorry about the delay with this Javed! Looks good modulo some small suggestions. Also, could you add some negative tests? Thanks!

 //===----------------------------------------------------------------------===//
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//

 bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural.
-  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+  // Structural and operands
Contributor

[nit] I don't understand this comment :) I appreciate that you are effectively inheriting this, but let's clarify. Does "Structural and operands" mean "check the structure (all loops parallel?) and the operands (single input/output?)"?

Also trying to make sure I understand 😅

Comment on lines 53 to 77
// Returns true if all loops of the linalgOp are parallel
static bool isAllParallel(LinalgOp op) {
  return op.getNumParallelLoops() == op.getNumLoops();
}

// Returns true if and only if linalgOp takes one input and one init.
static bool isSingleInputOutput(LinalgOp op) {
  return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
}
// Returns true if genericOp body is just a yieldOp that yields
// input operand as result.
static bool isSingleYieldOp(GenericOp op) {
  if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
    return false;

  Block *body = op.getBody();
  if (body->getOperations().size() != 1)
    return false;

  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
      yieldOp->getOperand(0) != body->getArgument(0))
    return false;
  return true;
}
Contributor

Do these belong here? IIUC, the comment above ("Interface utility functions") refers to ODS/TableGen "interfaces" (i.e. none of these is an InterfaceMethod).

Having said that, why not add them to the interface?

Comment on lines 129 to 130
if (!isa<MemRefType, RankedTensorType>(t0) ||
!isa<MemRefType, RankedTensorType>(t1))
Contributor

What else could it be? Perhaps checking for ShapedType would be enough?

!isSingleYieldOp(genericOp))
return std::nullopt;

// mapping checks.
Contributor

[nit]

Suggested change
// mapping checks.
// Check the maps.

// CHECK-NOT: linalg.generic
// CHECK: %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
//
func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
Contributor

[nit] Skip linalg in func name (repeating info already available)

Contributor

Perhaps add 3D and 1D cases? And an identity case?

Contributor Author

Added a 3D test. Thanks for the suggestion.
A 1D transpose test would just get DCE-d out.

@javedabsar1 (Contributor, Author) commented:

"is broadcast" functionality with y

Hi @rengolin and shahidact. My generic to naming is based on broadcast . Not a coupling of matmul and transpose/broadcast.
@rengolin please review my patch on merits of current linalg. generic to named has many named ops translation still pending. It is a good feature to have these pass through. If linalg.matmul expands to perm attribute as you propose, the current generic-to-named which translates contractions to matmul variations e.g. linalg.matmul_transpose_a ... I will extend to cover them as well.

Hey, sorry, I did not mean to review this PR with my tag, just to alert @shahidact to your PR so he can check his own PR against it.

Thanks Renato for reviewing this. Yes, I was a bit confused by your earlier statement. Apologies.

@javedabsar1 (Contributor, Author) commented:

> Sorry about the delay with this Javed! Looks good modulo some small suggestions. Also, could you add some negative tests? Thanks!

Added. Thanks.

@javedabsar1 (Contributor, Author) commented:

Hi @banach-space @rengolin
Thanks for reviewing. I think/hope I have addressed all your review concerns. Please have another look.

@banach-space (Contributor) left a comment

Thanks for addressing my comments. LGTM modulo the comment re the test file (did you mean to expand roundtrip.mlir?).

Renato has effectively already approved, but I'd wait till tomorrow to make sure he has no further comments.

Contributor

Why a new file instead of re-using roundtrip.mlir? Note that this file is called "roundtrip-broadcast.mlir", but it tests both broadcasts and transposes.

Contributor Author

I double-checked: the transpose and broadcast tests are in separate files, e.g. the linalg.transpose tests are in roundtrip-transpose.mlir. It may be that in the browser it appears mashed up across comments.

@rengolin (Member) left a comment

LGTM, thanks!

@javedabsar1 (Contributor, Author) commented:

> LGTM, thanks!

Thanks a lot @rengolin and @banach-space for helping improve the patch.

@javedabsar1 javedabsar1 merged commit c13f806 into llvm:main Oct 11, 2024
8 checks passed
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
Add support for specializing linalg.broadcast and linalg.transpose from generic. Also does some refactoring to reuse specialization checks, migrating some common checks to op interface methods.