Skip to content

[MLIR][Linalg] Add pass to convert linalg.generic back to named ops #95656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 30, 2024

Conversation

javedabsar1
Copy link
Contributor

@javedabsar1 javedabsar1 commented Jun 15, 2024

Add a new pass --linalg-specialize-generic-ops which lifts, where possible, linalg.generic to named ops.
Much like -linalg-generalize-named-ops lowers named ops to linalg.generic .
Also add patterns to recognize contractions which can be specialized from linalg.generic to named op: linalg.{batch_}?matmul{_transpose_(a|b)}?
Patterns to recognize elementwise unary/binary fills/copy were added previously and already exist.

Existing `-linalg-generalize-named-ops` lowers named ops to
linalg.generic. This patch adds `--linalg-specialize-generic-ops`
which converts, where possible, linalg.generic back to named ops.
Also, it adds patterns to recognize contractions which can be
specialized from linalg.generic to named op:
 `linalg.{batch_}?matmul{_transpose_(a|b)}?`
Patterns to recognize elementwise unary/binary fills/copy
were added previously and already exist.
@llvmbot
Copy link
Member

llvmbot commented Jun 15, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Javed Absar (javedabsar1)

Changes

Existing -linalg-generalize-named-ops lowers named ops to linalg.generic.

This patch adds --linalg-specialize-generic-ops which converts, where possible, linalg.generic back to named ops. Also, it adds patterns to recognize contractions which can be specialized from linalg.generic to named op:
linalg.{batch_}?matmul{_transpose_(a|b)}?

Patterns to recognize elementwise unary/binary fills/copy were added previously and already exist.


Patch is 25.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95656.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+5)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+23)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+205)
  • (added) mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir (+49)
  • (added) mlir/test/Dialect/Linalg/specialize-generic-ops.mlir (+37)
  • (added) mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir (+148)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0a4ce8953136d..6a60f7f3ea9f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -104,6 +104,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
+  let summary = "Convert generic ops back to named ops";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..912f9778a40e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1395,6 +1395,24 @@ struct LinalgGeneralizationPattern
   }
 };
 
+struct LinalgSpecializationPattern
+    : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+  FailureOr<LinalgOp>
+  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
+    auto genericOp = dyn_cast<GenericOp>(op.getOperation());
+    if (!genericOp)
+      return failure();
+    return specializeGenericOp(rewriter, genericOp);
+  }
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+};
+
 /// Vectorization pattern for memref::CopyOp.
 struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
   using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -1546,6 +1564,11 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 /// linalg.generic ops.
 void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns to convert linalg.generic ops to named
+/// ops where possible.
+void populateLinalgGenericOpsSpecializationPatterns(
+    RewritePatternSet &patterns);
+
 /// Linalg decompose convolutions patterns
 
 /// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f35ab3b856b4e..8ca76ec43193d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -107,7 +107,7 @@ isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
                                           unsigned arity) {
   // Check all loops are parallel, and have only tensor semantics.
   if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+      genericOp.getNumLoops() < 1)
     return false;
 
   // Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2bc4d7fbfadcc..7fac3feba98c9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,12 +11,22 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
+namespace mlir {
+#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
 #define DEBUG_TYPE "linalg-specialization"
 
 #define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
@@ -58,6 +68,175 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+/// Identifies linalg.generic that is essentially named op of the form:
+//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
+//
+// It is possible that a linalg.generic may be implementing one of matmul
+// variants but not in a straight-forward way, or the linalg.generic's
+// affine map per operand capture more semantics than is possible with
+// named op (which has implicit map interpreted via name).
+//
+// But a named linalg matmul variant that was 'generalized' should be
+// convertible back to named op here.
+//
+namespace {
+enum class IndexMatchResult {
+  Match = 0,  // identity map.
+  Transposed, // transposed map.
+  Mismatch    // none of the above.
+};
+
+// Looks at the affine map of an operand and works out if generic accesses
+// the element as identity-map, transposed, or 'cant work out'.
+// This check skips the `offset` batch indices and focuses on the matmul part.
+static IndexMatchResult matchOperandMap(AffineMap m, unsigned offset,
+                                        unsigned i, unsigned j) {
+  auto expr_ei = dyn_cast<AffineDimExpr>(m.getResults()[offset]);
+  auto expr_ej = dyn_cast<AffineDimExpr>(m.getResults()[offset + 1]);
+  if (!expr_ei || !expr_ej)
+    return IndexMatchResult::Mismatch;
+
+  auto ei = expr_ei.getPosition();
+  auto ej = expr_ej.getPosition();
+
+  if (ei == i && ej == j)
+    return IndexMatchResult::Match;
+
+  if (ei == j && ej == i)
+    return IndexMatchResult::Transposed;
+
+  return IndexMatchResult::Mismatch;
+}
+
+//  All the variants `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+//  have same number of input/output.
+template <typename Variant>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+  LinalgOp namedOp = rewriter.replaceOpWithNewOp<Variant>(
+      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
+      ValueRange{op.getDpsInits()[0]});
+  return namedOp;
+}
+
+// Converts linalg.generic to named linalg.*matmul* where possible.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
+  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+    return failure();
+
+  // Linalg generic contraction can be across multiple axis but for matmul
+  // variants it must be one.
+  if (genericOp.getNumReductionLoops() != 1)
+    return failure();
+
+  // Must be projected permutations.
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(mapRange,
+                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
+    return failure();
+
+  //  matmul contractions are of the form:
+  //  %0 = <elemwise>(permutation-of(cu(block-argument-0),
+  //                                 cu(block-argument-1)))
+  //  %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
+  //
+  //  where <elemwise> and <reduce> are binary operations constituting a
+  //  contraction (in the canonical case, <elemwise> is a multiplication and
+  //  <reduce> is an addition). All operands of all operations may be supplied
+  //  through a chain of side effect-free unary operations, such as casts,
+  //  which is denoted as `cu` above.
+  if (!mlir::linalg::detail::isContractionBody(
+          *genericOp.getBlock(), [](Operation *first, Operation *second) {
+            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
+              return true;
+            return false;
+          }))
+    return failure();
+
+  // Finds 2 parallel (m and n) and 1 reduction (k) dimension candidates that
+  // form a matmul subcomputation. These dimensions are such that:
+  //   1. The m dimension is involved in an outer-product along LHS
+  //      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+  //   2. The n dimension is involved in an outer-product along RHS
+  //      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+  //   3. The k dimension appears as a permutation on LHS and RHS.
+  //   4. m, n and k appear only once in any given indexing.
+  //   5. Optional batch dimensions that appear in all operands are captured.
+  auto res = inferContractionDims(genericOp);
+  assert(succeeded(res) && "unexpected failure to infer contraction dims");
+  auto dims = *res;
+
+  // Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
+  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+    return failure();
+
+  // Check rank of operands
+  auto indexingMaps = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
+        return m.getResults().size() !=
+               dims.batch.size() + 2 /*two from {m,n,k}*/;
+      }))
+    return failure();
+
+  auto batchSize = dims.batch.size();
+  if (indexingMaps[0].getNumDims() != batchSize + 3) {
+  }
+  if (batchSize) {
+    // Each operand in a linalg generic contraction  could express different
+    // permutations for its batch dimension. But for named op it must be
+    // identity since separate maps are not specified.
+    if (llvm::any_of(indexingMaps, [batchSize](AffineMap m) {
+          for (unsigned i = 0; i < batchSize; ++i) {
+            auto expr = dyn_cast<AffineDimExpr>(m.getResults()[i]);
+            if (!expr || expr.getPosition() != i)
+              return true;
+          }
+          return false;
+        }))
+      return failure();
+  }
+
+  auto a = matchOperandMap(indexingMaps[0], batchSize, dims.m[0], dims.k[0]);
+  auto b = matchOperandMap(indexingMaps[1], batchSize, dims.k[0], dims.n[0]);
+  auto c = matchOperandMap(indexingMaps[2], batchSize, dims.m[0], dims.n[0]);
+
+  if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
+        return r == IndexMatchResult::Mismatch;
+      }))
+    return failure();
+
+  if (c != IndexMatchResult::Match ||
+      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
+    return failure();
+
+  /// Codegen the different matmul variants.
+  if (batchSize) {
+    if (a == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+                                                               genericOp);
+    if (b == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+                                                               genericOp);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+  }
+
+  if (a == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
+  if (b == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Categorize linalg generic to named op where possible.
+//===----------------------------------------------------------------------===//
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
   if (isaCopyOpInterface(genericOp)) {
@@ -100,5 +279,31 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
       return namedOp;
     }
   }
+
+  if (isaContractionOpInterface(genericOp)) {
+    return specializeLinalgContractions(rewriter, genericOp);
+  }
   return failure();
 }
+
+namespace {
+struct LinalgSpecializeGenericOpsPass
+    : public impl::LinalgSpecializeGenericOpsPassBase<
+          LinalgSpecializeGenericOpsPass> {
+
+  using impl::LinalgSpecializeGenericOpsPassBase<
+      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
+  void runOnOperation() override;
+};
+} // namespace
+
+void LinalgSpecializeGenericOpsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  populateLinalgGenericOpsSpecializationPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
new file mode 100644
index 0000000000000..d258d9f518534
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @roundtrip_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: roundtrip_add
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>,  %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.exp ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK-LABEL: roundtrip_exp
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+
+// -----
+
+func.func @roundtrip_gemm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.add ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_gemm
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
new file mode 100644
index 0000000000000..0ec2dc3a92ec7
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic 
+         {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+         ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.divf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = math.exp %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
new file mode 100644
index 0000000000000..f64953bceefe1
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @specialize_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+          {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+          ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %0 = arith.mulf %in, %in_0 : f32
+      %1 = arith.addf %out, %0 : f32
+      linalg.yield %1 : f32
+    } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @specialize_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.generic
+     {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+     ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2 : memref<3x7xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+       %0 = arith.mulf %in, %in_0 : f32
+       %1 = arith.addf %out, %0 : f32
+       linalg.yield %1 : f32...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 15, 2024

@llvm/pr-subscribers-mlir

Author: Javed Absar (javedabsar1)

Changes

Existing -linalg-generalize-named-ops lowers named ops to linalg.generic.

This patch adds --linalg-specialize-generic-ops which converts, where possible, linalg.generic back to named ops. Also, it adds patterns to recognize contractions which can be specialized from linalg.generic to named op:
linalg.{batch_}?matmul{_transpose_(a|b)}?

Patterns to recognize elementwise unary/binary fills/copy were added previously and already exist.


Patch is 25.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95656.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+5)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+23)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+205)
  • (added) mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir (+49)
  • (added) mlir/test/Dialect/Linalg/specialize-generic-ops.mlir (+37)
  • (added) mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir (+148)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0a4ce8953136d..6a60f7f3ea9f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -104,6 +104,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
+  let summary = "Convert generic ops back to named ops";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..912f9778a40e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1395,6 +1395,24 @@ struct LinalgGeneralizationPattern
   }
 };
 
+struct LinalgSpecializationPattern
+    : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+  FailureOr<LinalgOp>
+  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
+    auto genericOp = dyn_cast<GenericOp>(op.getOperation());
+    if (!genericOp)
+      return failure();
+    return specializeGenericOp(rewriter, genericOp);
+  }
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+};
+
 /// Vectorization pattern for memref::CopyOp.
 struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
   using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -1546,6 +1564,11 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 /// linalg.generic ops.
 void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns to convert linalg.generic ops to named
+/// ops where possible.
+void populateLinalgGenericOpsSpecializationPatterns(
+    RewritePatternSet &patterns);
+
 /// Linalg decompose convolutions patterns
 
 /// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f35ab3b856b4e..8ca76ec43193d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -107,7 +107,7 @@ isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
                                           unsigned arity) {
   // Check all loops are parallel, and have only tensor semantics.
   if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+      genericOp.getNumLoops() < 1)
     return false;
 
   // Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2bc4d7fbfadcc..7fac3feba98c9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,12 +11,22 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
+namespace mlir {
+#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
 #define DEBUG_TYPE "linalg-specialization"
 
 #define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
@@ -58,6 +68,175 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+/// Identifies linalg.generic that is essentially named op of the form:
+//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
+//
+// It is possible that a linalg.generic may be implementing one of matmul
+// variants but not in a straight-forward way, or the linalg.generic's
+// affine map per operand capture more semantics than is possible with
+// named op (which has implicit map interpreted via name).
+//
+// But a named linalg matmul variant that was 'generalized' should be
+// convertible back to named op here.
+//
+namespace {
+enum class IndexMatchResult {
+  Match = 0,  // identity map.
+  Transposed, // transposed map.
+  Mismatch    // none of the above.
+};
+
+// Looks at the affine map of an operand and works out if generic accesses
+// the element as identity-map, transposed, or 'cant work out'.
+// This check skips the `offset` batch indices and focuses on the matmul part.
+static IndexMatchResult matchOperandMap(AffineMap m, unsigned offset,
+                                        unsigned i, unsigned j) {
+  auto expr_ei = dyn_cast<AffineDimExpr>(m.getResults()[offset]);
+  auto expr_ej = dyn_cast<AffineDimExpr>(m.getResults()[offset + 1]);
+  if (!expr_ei || !expr_ej)
+    return IndexMatchResult::Mismatch;
+
+  auto ei = expr_ei.getPosition();
+  auto ej = expr_ej.getPosition();
+
+  if (ei == i && ej == j)
+    return IndexMatchResult::Match;
+
+  if (ei == j && ej == i)
+    return IndexMatchResult::Transposed;
+
+  return IndexMatchResult::Mismatch;
+}
+
+//  All the variants `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+//  have same number of input/output.
+template <typename Variant>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+  LinalgOp namedOp = rewriter.replaceOpWithNewOp<Variant>(
+      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
+      ValueRange{op.getDpsInits()[0]});
+  return namedOp;
+}
+
+// Converts linalg.generic to named linalg.*matmul* where possible.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
+  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+    return failure();
+
+  // Linalg generic contraction can be across multiple axis but for matmul
+  // variants it must be one.
+  if (genericOp.getNumReductionLoops() != 1)
+    return failure();
+
+  // Must be projected permutations.
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(mapRange,
+                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
+    return failure();
+
+  //  matmul contractions are of the form:
+  //  %0 = <elemwise>(permutation-of(cu(block-argument-0),
+  //                                 cu(block-argument-1)))
+  //  %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
+  //
+  //  where <elemwise> and <reduce> are binary operations constituting a
+  //  contraction (in the canonical case, <elemwise> is a multiplication and
+  //  <reduce> is an addition). All operands of all operations may be supplied
+  //  through a chain of side effect-free unary operations, such as casts,
+  //  which is denoted as `cu` above.
+  if (!mlir::linalg::detail::isContractionBody(
+          *genericOp.getBlock(), [](Operation *first, Operation *second) {
+            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
+              return true;
+            return false;
+          }))
+    return failure();
+
+  // Finds 2 parallel (m and n) and 1 reduction (k) dimension candidates that
+  // form a matmul subcomputation. These dimensions are such that:
+  //   1. The m dimension is involved in an outer-product along LHS
+  //      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+  //   2. The n dimension is involved in an outer-product along RHS
+  //      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+  //   3. The k dimension appears as a permutation on LHS and RHS.
+  //   4. m, n and k appear only once in any given indexing.
+  //   5. Optional batch dimensions that appear in all operands are captured.
+  auto res = inferContractionDims(genericOp);
+  assert(succeeded(res) && "unexpected failure to infer contraction dims");
+  auto dims = *res;
+
+  // Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
+  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+    return failure();
+
+  // Check rank of operands
+  auto indexingMaps = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
+        return m.getResults().size() !=
+               dims.batch.size() + 2 /*two from {m,n,k}*/;
+      }))
+    return failure();
+
+  auto batchSize = dims.batch.size();
+  if (indexingMaps[0].getNumDims() != batchSize + 3) {
+  }
+  if (batchSize) {
+    // Each operand in a linalg generic contraction  could express different
+    // permutations for its batch dimension. But for named op it must be
+    // identity since separate maps are not specified.
+    if (llvm::any_of(indexingMaps, [batchSize](AffineMap m) {
+          for (unsigned i = 0; i < batchSize; ++i) {
+            auto expr = dyn_cast<AffineDimExpr>(m.getResults()[i]);
+            if (!expr || expr.getPosition() != i)
+              return true;
+          }
+          return false;
+        }))
+      return failure();
+  }
+
+  auto a = matchOperandMap(indexingMaps[0], batchSize, dims.m[0], dims.k[0]);
+  auto b = matchOperandMap(indexingMaps[1], batchSize, dims.k[0], dims.n[0]);
+  auto c = matchOperandMap(indexingMaps[2], batchSize, dims.m[0], dims.n[0]);
+
+  if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
+        return r == IndexMatchResult::Mismatch;
+      }))
+    return failure();
+
+  if (c != IndexMatchResult::Match ||
+      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
+    return failure();
+
+  /// Codegen the different matmul variants.
+  if (batchSize) {
+    if (a == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+                                                               genericOp);
+    if (b == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+                                                               genericOp);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+  }
+
+  if (a == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
+  if (b == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Categorize linalg generic to named op where possible.
+//===----------------------------------------------------------------------===//
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
   if (isaCopyOpInterface(genericOp)) {
@@ -100,5 +279,31 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
       return namedOp;
     }
   }
+
+  if (isaContractionOpInterface(genericOp)) {
+    return specializeLinalgContractions(rewriter, genericOp);
+  }
   return failure();
 }
+
+namespace {
+struct LinalgSpecializeGenericOpsPass
+    : public impl::LinalgSpecializeGenericOpsPassBase<
+          LinalgSpecializeGenericOpsPass> {
+
+  using impl::LinalgSpecializeGenericOpsPassBase<
+      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
+  void runOnOperation() override;
+};
+} // namespace
+
+void LinalgSpecializeGenericOpsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  populateLinalgGenericOpsSpecializationPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
new file mode 100644
index 0000000000000..d258d9f518534
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @roundtrip_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: roundtrip_add
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>,  %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.exp ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK-LABEL: roundtrip_exp
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+
+// -----
+
+func.func @roundtrip_gemm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.add ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_gemm
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
new file mode 100644
index 0000000000000..0ec2dc3a92ec7
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic 
+         {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+         ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.divf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = math.exp %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
new file mode 100644
index 0000000000000..f64953bceefe1
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @specialize_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+          {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+          ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %0 = arith.mulf %in, %in_0 : f32
+      %1 = arith.addf %out, %0 : f32
+      linalg.yield %1 : f32
+    } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @specialize_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.generic
+     {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+     ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2 : memref<3x7xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+       %0 = arith.mulf %in, %in_0 : f32
+       %1 = arith.addf %out, %0 : f32
+       linalg.yield %1 : f32...
[truncated]

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking a stab at this utility. It will definitely come in handy.
Quick initial pass for now.

@javedabsar1
Copy link
Contributor Author

Thanks for taking a stab at this utility. It will definitely come in handy. Quick initial pass for now.

Thanks.

@javedabsar1
Copy link
Contributor Author

Thanks for reviewing

void LinalgSpecializeGenericOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLinalgGenericOpsSpecializationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please dont ignore the result here. Catch the error and raise failure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping!
made changes based on comments.

Comment on lines 97 to 98
auto expr_ei = dyn_cast<AffineDimExpr>(m.getResults()[offset]);
auto expr_ej = dyn_cast<AffineDimExpr>(m.getResults()[offset + 1]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -1395,6 +1395,24 @@ struct LinalgGeneralizationPattern
}
};

struct LinalgSpecializationPattern
: public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can just be an OpRewritePattern<GenericOp> instead of doing a cast from LinalgOp.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome suggestion. simplifies. Fixed.

@javedabsar1
Copy link
Contributor Author

Gentle Ping.

@javedabsar1 javedabsar1 requested a review from qedawkins June 24, 2024 23:26
Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you planning to add support for broadcast/reduction/type cast soon?

// contraction (in the canonical case, <elemwise> is a multiplication and
// <reduce> is an addition). All operands of all operations may be supplied
// through a chain of side effect-free unary operations, such as casts,
// which is denoted as `cu` above.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potentially a daft question - what does cu stand for?

Edit - ah, I see that this is a c&p from https://mlir.llvm.org/doxygen/namespacemlir_1_1linalg_1_1detail.html#a3b205bd5642da72c053d6ca8323970b5? Let's avoid duplication of comments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted. I thought it was important to emphasis in this context. but yeah. people can read the LinalgInterfaces directly.

Comment on lines 179 to 188
// Finds 2 parallel (m and n) and 1 reduction (k) dimension candidates that
// form a matmul subcomputation. These dimensions are such that:
// 1. The m dimension is involved in an outer-product along LHS
// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
// 2. The n dimension is involved in an outer-product along RHS
// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
// 3. The k dimension appears as a permutation on LHS and RHS.
// 4. m, n and k appear only once in any given indexing.
// 5. Optional batch dimensions that appear in all operands are captured.
auto res = inferContractionDims(genericOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert(succeeded(res) && "unexpected failure to infer contraction dims");
auto dims = *res;

// Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Surely in practice M,N and K will be =! 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

contraction can be along more than one dim as far as linalg contraction is concerned.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I understand where the confusion is coming from. For me, "dim sizes" in A = M x K are M and K. Whereas this is checking e.g. the number of "contraction dims corresponding to K as inferred by inferControctionDims"? So, M and K can indeed be "1". Could you clarify in the comment?

Also, are you able to add a negative test to exercise this check?

Comment on lines 103 to 107
// Consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. Below, we
// check whether the index map of A is identity (match), transposed, or
// something completely different (mis-match).
// The naming and explanation is in terms of A, but the function checks
// effectively maps for all A, B, C i.e. <M,N>, <M, K>, <K,N>.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Any chance to drop references to A, B and C? This seems like a very generic utility, but the naming seems to over index on a very specific use-case. For example, why not use expectedPosDim1 and expectedPosDim2? That would be clearer to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK i tried following your suggestion and then it reads more confusing because I have to then s/exprOfM/exprPosDim1.
The idea was one reads it with one thing in mind, e.g. A in A[m,k] * B[k,n] . Then once one gets it, person gets the hang of it and generalized. Instead of getting confused over long names.
I added more comments to make it easier read.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand what you are trying to achieve here, but I still find the documentation confusing 😅 I appreciate that that's a subjective matter, but IMHO a specific example should be use to complement a generic description rather than replace one.

Here's a complete suggestion. Note, I replaced Dim1 and Dim2from my original suggestion withRowDimandColDim` (IIUC, that's what these are):

// Checks whether the input Affine `map` contains two consecutive dims that can
// be interpreted as accessing a 2D matrix. It is assumed that the row and
// column dimension are located next to each other (in this order) and start at
// `rowDimIdx` in the input map.
//
// YOUR SPECIFIC EXAMPLE WITH MATRIX A <<HERE>>
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // Get the matrix multiply indices. They are past the batch indices.
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // They should be pure dim ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}

Feel free to re-use (and/or change). I just feel that we shouldn't be referring to batchSize and/or "dimension M"/"dimension K" in such a generic hook. For example, from the point of view of this hook it doesn't matter what the batch size is, neither does what "M" and "K" are.

HTH

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Will amend based on your suggestion :)

@rengolin
Copy link
Member

Thanks @javedabsar1, I have exhausted my questions and will let @banach-space finish the review and approve.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've finally managed to scan the whole thing. Overall LG, thanks! And thanks again for the contribution!

One thing that's still missing - some negative tests. I've made a few specific suggestions inline. You could also add one for e.g. linalg.matvec and mark it as TODO :)

I've also left a couple of more small suggestion inline, but nothing major.

Also, looks like most files in "mlir/test/Dialect/Linalg" use hyphen (-) rather than underscore in filenames (_). For consistency:

  • "transform-op-specialize_matmul.mlir" -> "transform-op-specialize-matmul.mlir"

Same suggestion for "transform-op-specialize_elemwise_binary.mlir" and "transform-op-specialize_elemwise_unary.mlir" (the only other 2 files with inconsistent naming)

assert(succeeded(res) && "unexpected failure to infer contraction dims");
auto dims = *res;

// Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I understand where the confusion is coming from. For me, "dim sizes" in A = M x K are M and K. Whereas this is checking e.g. the number of "contraction dims corresponding to K as inferred by inferControctionDims"? So, M and K can indeed be "1". Could you clarify in the comment?

Also, are you able to add a negative test to exercise this check?

Comment on lines +80 to +91
// %0 = linalg.generic {
// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
// iterator_types = ["parallel", "parallel", "parallel"]}
// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
// outs(%C : tensor<20x20x20xf32>) {
// ^bb0(%a: f32, %b: f32, %c : f32):
// %mul = arith.mulf %a, %b : f32
// %add = arith.addf %mul, %c : f32
// linalg.yield %add : f32
// } -> tensor<20x20x20xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use it to write a negative test?

}))
return failure();

auto batchSize = dims.batch.size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, "batch size" would be 2 in this example:

 linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>

Whereas batchSize = dims.batch.size() is "the number of batch dims", so 1. I suggest renaming batchSize as numOfBatchDims.

Also, would the number of batch dims be ever != 1?

- renaming batchSize as numOfBatchDims
- example of multiple contraction dims corresponding to K as inferred by inferControctionDims
- add linalg.matvec and mark it as TODO
- use hyphen (-) rather than underscore in filenames (_).
- implement banach suggestion to replace M K in explanation with expectedPosOfRowDim etc
@javedabsar1
Copy link
Contributor Author

Made all changes requested as much as possible.

@javedabsar1 javedabsar1 requested a review from banach-space June 28, 2024 23:04
Copy link

github-actions bot commented Jun 28, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for addressing my comments and for seeing this through - super handy!

I've left one nit - feel free to ignore. I believe that you've already addressed comments from other reviewers, so this can be merged. Please don't forget to fix the formatting ;-)

Thanks again!

#mapA = affine_map<(m, n, k1, k2) -> (m, k1, k2)>
#mapB = affine_map<(m, n, k1, k2) -> (k2, k1, n)>
#mapC = affine_map<(m, n, k1, k2) -> (m, n)>
func.func @op_multi_reduction(%A: tensor<10x20x30xf32>, %B: tensor<30x20x40xf32>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] You could add "negative" to the test name (e.g. @negative_op_multi_reduction)

@javedabsar1 javedabsar1 merged commit 3efac5c into llvm:main Jun 30, 2024
7 checks passed
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
…lvm#95656)

Add a new mlir-opt  pass `--linalg-specialize-generic-ops` which lifts generic,
where possible, to linalg named ops.
Much like `-linalg-generalize-named-ops` lowers named ops to linalg.generic .
Also add patterns to recognize contractions which can be specialized from 
linalg.generic to named op: `linalg.{batch_}?matmul{_transpose_(a|b)}?`
@11happy
Copy link
Contributor

11happy commented Jan 22, 2025

@javedabsar1 Does this pass change linalg generic which is just copying tensor from input to output, to linalg.copy?
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants