[MLIR][Linalg] Add pass to convert linalg.generic back to named ops #95656
Conversation
Existing `-linalg-generalize-named-ops` lowers named ops to linalg.generic. This patch adds `--linalg-specialize-generic-ops`, which converts, where possible, linalg.generic back to named ops. It also adds patterns to recognize contractions that can be specialized from linalg.generic to the named ops `linalg.{batch_}?matmul{_transpose_(a|b)}?`. Patterns to recognize elementwise unary/binary, fill, and copy ops were added previously and already exist.
@llvm/pr-subscribers-mlir-linalg Author: Javed Absar (javedabsar1) Changes: Existing `-linalg-generalize-named-ops` lowers named ops to linalg.generic. This patch adds `--linalg-specialize-generic-ops`, which converts, where possible, linalg.generic back to named ops. Patterns to recognize elementwise unary/binary, fill, and copy ops were added previously and already exist. Patch is 25.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95656.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0a4ce8953136d..6a60f7f3ea9f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -104,6 +104,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
+ let summary = "Convert generic ops back to named ops";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..912f9778a40e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1395,6 +1395,24 @@ struct LinalgGeneralizationPattern
}
};
+struct LinalgSpecializationPattern
+ : public OpInterfaceRewritePattern<LinalgOp> {
+ using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+ FailureOr<LinalgOp>
+ returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
+ auto genericOp = dyn_cast<GenericOp>(op.getOperation());
+ if (!genericOp)
+ return failure();
+ return specializeGenericOp(rewriter, genericOp);
+ }
+
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
+};
+
/// Vectorization pattern for memref::CopyOp.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -1546,6 +1564,11 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns to convert linalg.generic ops to named
+/// ops where possible.
+void populateLinalgGenericOpsSpecializationPatterns(
+ RewritePatternSet &patterns);
+
/// Linalg decompose convolutions patterns
/// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f35ab3b856b4e..8ca76ec43193d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -107,7 +107,7 @@ isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
// Check all loops are parallel, and have only tensor semantics.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
- genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+ genericOp.getNumLoops() < 1)
return false;
// Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2bc4d7fbfadcc..7fac3feba98c9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,12 +11,22 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
+namespace mlir {
+#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
#define DEBUG_TYPE "linalg-specialization"
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
@@ -58,6 +68,175 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
return swapped;
}
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+/// Identifies linalg.generic that is essentially named op of the form:
+// ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
+//
+// It is possible that a linalg.generic may be implementing one of matmul
+// variants but not in a straight-forward way, or the linalg.generic's
+// affine map per operand capture more semantics than is possible with
+// named op (which has implicit map interpreted via name).
+//
+// But a named linalg matmul variant that was 'generalized' should be
+// convertible back to named op here.
+//
+namespace {
+enum class IndexMatchResult {
+ Match = 0, // identity map.
+ Transposed, // transposed map.
+ Mismatch // none of the above.
+};
+
+// Looks at the affine map of an operand and works out if generic accesses
+// the element as identity-map, transposed, or 'cant work out'.
+// This check skips the `offset` batch indices and focuses on the matmul part.
+static IndexMatchResult matchOperandMap(AffineMap m, unsigned offset,
+ unsigned i, unsigned j) {
+ auto expr_ei = dyn_cast<AffineDimExpr>(m.getResults()[offset]);
+ auto expr_ej = dyn_cast<AffineDimExpr>(m.getResults()[offset + 1]);
+ if (!expr_ei || !expr_ej)
+ return IndexMatchResult::Mismatch;
+
+ auto ei = expr_ei.getPosition();
+ auto ej = expr_ej.getPosition();
+
+ if (ei == i && ej == j)
+ return IndexMatchResult::Match;
+
+ if (ei == j && ej == i)
+ return IndexMatchResult::Transposed;
+
+ return IndexMatchResult::Mismatch;
+}
+
+// All the variants `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+// have same number of input/output.
+template <typename Variant>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<Variant>(
+ op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
+ ValueRange{op.getDpsInits()[0]});
+ return namedOp;
+}
+
+// Converts linalg.generic to named linalg.*matmul* where possible.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+ return failure();
+
+ // Linalg generic contraction can be across multiple axis but for matmul
+ // variants it must be one.
+ if (genericOp.getNumReductionLoops() != 1)
+ return failure();
+
+ // Must be projected permutations.
+ auto mapRange = genericOp.getIndexingMapsArray();
+ if (llvm::any_of(mapRange,
+ [](AffineMap m) { return !m.isProjectedPermutation(); }))
+ return failure();
+
+ // matmul contractions are of the form:
+ // %0 = <elemwise>(permutation-of(cu(block-argument-0),
+ // cu(block-argument-1)))
+ // %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
+ //
+ // where <elemwise> and <reduce> are binary operations constituting a
+ // contraction (in the canonical case, <elemwise> is a multiplication and
+ // <reduce> is an addition). All operands of all operations may be supplied
+ // through a chain of side effect-free unary operations, such as casts,
+ // which is denoted as `cu` above.
+ if (!mlir::linalg::detail::isContractionBody(
+ *genericOp.getBlock(), [](Operation *first, Operation *second) {
+ if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+ (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+ (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
+ return true;
+ return false;
+ }))
+ return failure();
+
+ // Finds 2 parallel (m and n) and 1 reduction (k) dimension candidates that
+ // form a matmul subcomputation. These dimensions are such that:
+ // 1. The m dimension is involved in an outer-product along LHS
+ // (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+ // 2. The n dimension is involved in an outer-product along RHS
+ // (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+ // 3. The k dimension appears as a permutation on LHS and RHS.
+ // 4. m, n and k appear only once in any given indexing.
+ // 5. Optional batch dimensions that appear in all operands are captured.
+ auto res = inferContractionDims(genericOp);
+ assert(succeeded(res) && "unexpected failure to infer contraction dims");
+ auto dims = *res;
+
+ // Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
+ if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+ return failure();
+
+ // Check rank of operands
+ auto indexingMaps = genericOp.getIndexingMapsArray();
+ if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
+ return m.getResults().size() !=
+ dims.batch.size() + 2 /*two from {m,n,k}*/;
+ }))
+ return failure();
+
+ auto batchSize = dims.batch.size();
+ if (indexingMaps[0].getNumDims() != batchSize + 3) {
+ }
+ if (batchSize) {
+ // Each operand in a linalg generic contraction could express different
+ // permutations for its batch dimension. But for named op it must be
+ // identity since separate maps are not specified.
+ if (llvm::any_of(indexingMaps, [batchSize](AffineMap m) {
+ for (unsigned i = 0; i < batchSize; ++i) {
+ auto expr = dyn_cast<AffineDimExpr>(m.getResults()[i]);
+ if (!expr || expr.getPosition() != i)
+ return true;
+ }
+ return false;
+ }))
+ return failure();
+ }
+
+ auto a = matchOperandMap(indexingMaps[0], batchSize, dims.m[0], dims.k[0]);
+ auto b = matchOperandMap(indexingMaps[1], batchSize, dims.k[0], dims.n[0]);
+ auto c = matchOperandMap(indexingMaps[2], batchSize, dims.m[0], dims.n[0]);
+
+ if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
+ return r == IndexMatchResult::Mismatch;
+ }))
+ return failure();
+
+ if (c != IndexMatchResult::Match ||
+ (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
+ return failure();
+
+ /// Codegen the different matmul variants.
+ if (batchSize) {
+ if (a == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+ genericOp);
+ if (b == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+ genericOp);
+ return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+ }
+
+ if (a == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
+ if (b == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
+ return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Categorize linalg generic to named op where possible.
+//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
if (isaCopyOpInterface(genericOp)) {
@@ -100,5 +279,31 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
return namedOp;
}
}
+
+ if (isaContractionOpInterface(genericOp)) {
+ return specializeLinalgContractions(rewriter, genericOp);
+ }
return failure();
}
+
+namespace {
+struct LinalgSpecializeGenericOpsPass
+ : public impl::LinalgSpecializeGenericOpsPassBase<
+ LinalgSpecializeGenericOpsPass> {
+
+ using impl::LinalgSpecializeGenericOpsPassBase<
+ LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
+ void runOnOperation() override;
+};
+} // namespace
+
+void LinalgSpecializeGenericOpsPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateLinalgGenericOpsSpecializationPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<LinalgSpecializationPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
new file mode 100644
index 0000000000000..d258d9f518534
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @roundtrip_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: roundtrip_add
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+ linalg.exp ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK-LABEL: roundtrip_exp
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+
+// -----
+
+func.func @roundtrip_gemm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.add ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_gemm
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
new file mode 100644
index 0000000000000..0ec2dc3a92ec7
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.divf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = math.exp %in : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
new file mode 100644
index 0000000000000..f64953bceefe1
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @specialize_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.mulf %in, %in_0 : f32
+ %1 = arith.addf %out, %0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @specialize_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2 : memref<3x7xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.mulf %in, %in_0 : f32
+ %1 = arith.addf %out, %0 : f32
+ linalg.yield %1 : f32...
[truncated]
Thanks for taking a stab at this utility. It will definitely come in handy.
Quick initial pass for now.
Thanks.
Thanks for reviewing.
void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
Please don't ignore the result here. Catch the error and raise failure.
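A minimal way to act on that suggestion (a sketch only, assuming the existing MLIR greedy-driver and pass APIs; the actual follow-up commit may differ) is to report non-convergence as a pass failure:

void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  // Propagate non-convergence of the greedy rewrite as a pass failure
  // instead of silently discarding the result.
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}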
Ping!
made changes based on comments.
auto expr_ei = dyn_cast<AffineDimExpr>(m.getResults()[offset]);
auto expr_ej = dyn_cast<AffineDimExpr>(m.getResults()[offset + 1]);
@@ -1395,6 +1395,24 @@ struct LinalgGeneralizationPattern
  }
};

struct LinalgSpecializationPattern
    : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
This can just be an OpRewritePattern<GenericOp> instead of doing a cast from LinalgOp.
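A sketch of that simplification (hypothetical; it mirrors the existing LinalgGeneralizationPattern in Transforms.h and assumes the same file context, so the merged version may differ in details):

struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Matching GenericOp directly removes the need for the dyn_cast from
  // LinalgOp; specializeGenericOp performs the actual rewrite.
  FailureOr<LinalgOp>
  returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
    return specializeGenericOp(rewriter, op);
  }

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }
};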
Awesome suggestion. Simplifies. Fixed.
Gentle ping.
Are you planning to add support for broadcast/reduction/type cast soon?
// contraction (in the canonical case, <elemwise> is a multiplication and
// <reduce> is an addition). All operands of all operations may be supplied
// through a chain of side effect-free unary operations, such as casts,
// which is denoted as `cu` above.
Potentially a daft question - what does cu stand for?
Edit - ah, I see that this is a c&p from https://mlir.llvm.org/doxygen/namespacemlir_1_1linalg_1_1detail.html#a3b205bd5642da72c053d6ca8323970b5? Let's avoid duplication of comments.
Deleted. I thought it was important to emphasize in this context, but yeah, people can read the LinalgInterfaces directly.
// Finds 2 parallel (m and n) and 1 reduction (k) dimension candidates that
// form a matmul subcomputation. These dimensions are such that:
//  1. The m dimension is involved in an outer-product along LHS
//     (i.e. it is a permutation on RES and LHS and does not appear in RHS).
//  2. The n dimension is involved in an outer-product along RHS
//     (i.e. it is a permutation on RES and RHS and does not appear in LHS).
//  3. The k dimension appears as a permutation on LHS and RHS.
//  4. m, n and k appear only once in any given indexing.
//  5. Optional batch dimensions that appear in all operands are captured.
auto res = inferContractionDims(genericOp);
No need to repeat the original comment: https://mlir.llvm.org/doxygen/namespacemlir_1_1linalg.html#aa2fe10e20900f7c49da8d51805f9e9f0
assert(succeeded(res) && "unexpected failure to infer contraction dims");
auto dims = *res;

// Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
Why? Surely in practice M, N and K will be != 1?
Contraction can be along more than one dim as far as linalg contraction is concerned.
Now I understand where the confusion is coming from. For me, "dim sizes" in A = M x K are M and K. Whereas this is checking e.g. the number of "contraction dims corresponding to K" as inferred by inferContractionDims? So, M and K can indeed be "1". Could you clarify in the comment?
Also, are you able to add a negative test to exercise this check?
// Consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. Below, we
// check whether the index map of A is identity (match), transposed, or
// something completely different (mis-match).
// The naming and explanation is in terms of A, but the function checks
// effectively maps for all A, B, C i.e. <M,N>, <M, K>, <K,N>.
[nit] Any chance to drop references to A, B and C? This seems like a very generic utility, but the naming seems to over-index on a very specific use-case. For example, why not use expectedPosDim1 and expectedPosDim2? That would be clearer to me.
OK, I tried following your suggestion and it reads more confusing, because I then have to s/exprOfM/exprPosDim1.
The idea was that one reads it with one thing in mind, e.g. A in A[m,k] * B[k,n]. Once one gets that, one gets the hang of it and generalizes, instead of getting confused over long names.
I added more comments to make it easier to read.
I understand what you are trying to achieve here, but I still find the documentation confusing 😅 I appreciate that that's a subjective matter, but IMHO a specific example should be used to complement a generic description rather than replace one.
Here's a complete suggestion. Note, I replaced Dim1 and Dim2 from my original suggestion with RowDim and ColDim (IIUC, that's what these are):
// Checks whether the input Affine `map` contains two consecutive dims that can
// be interpreted as accessing a 2D matrix. It is assumed that the row and
// column dimension are located next to each other (in this order) and start at
// `rowDimIdx` in the input map.
//
// YOUR SPECIFIC EXAMPLE WITH MATRIX A <<HERE>>
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
unsigned expectedPosOfRowDim,
unsigned expectedPosOfColDim) {
// Get the matrix multiply indices. They are past the batch indices.
auto exprOfRowDim = map.getResults()[rowDimIdx];
auto exprOfColDim = map.getResults()[rowDimIdx + 1];
// They should be pure dim ids.
if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
exprOfColDim.getKind() != AffineExprKind::DimId)
return IndexMatchResult::Mismatch;
auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
return IndexMatchResult::Match;
if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
return IndexMatchResult::Transposed;
return IndexMatchResult::Mismatch;
}
Feel free to re-use (and/or change). I just feel that we shouldn't be referring to batchSize and/or "dimension M"/"dimension K" in such a generic hook. For example, from the point of view of this hook it doesn't matter what the batch size is, nor what "M" and "K" are.
HTH
Thanks. Will amend based on your suggestion :)
Thanks @javedabsar1, I have exhausted my questions and will let @banach-space finish the review and approve.
I've finally managed to scan the whole thing. Overall LG, thanks! And thanks again for the contribution!
One thing that's still missing - some negative tests. I've made a few specific suggestions inline. You could also add one for e.g. linalg.matvec and mark it as TODO :)
I've also left a couple more small suggestions inline, but nothing major.
Also, looks like most files in "mlir/test/Dialect/Linalg" use hyphen (-) rather than underscore (_) in filenames. For consistency:
- "transform-op-specialize_matmul.mlir" -> "transform-op-specialize-matmul.mlir"
Same suggestion for "transform-op-specialize_elemwise_binary.mlir" and "transform-op-specialize_elemwise_unary.mlir" (the only other 2 files with inconsistent naming)
// %0 = linalg.generic {
//        indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                         affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                         affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//        iterator_types = ["parallel", "parallel", "parallel"]}
//        ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
//        outs(%C : tensor<20x20x20xf32>) {
//   ^bb0(%a: f32, %b: f32, %c : f32):
//     %mul = arith.mulf %a, %b : f32
//     %add = arith.addf %mul, %c : f32
//     linalg.yield %add : f32
// } -> tensor<20x20x20xf32>
Could you use it to write a negative test?
  }))
    return failure();

  auto batchSize = dims.batch.size();
To me, "batch size" would be 2
in this example:
linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
Whereas batchSize = dims.batch.size()
is "the number of batch dims", so 1
. I suggest renaming batchSize
as numOfBatchDims
.
Also, would the number of batch dims be ever != 1?
- renaming batchSize as numOfBatchDims
- example of multiple contraction dims corresponding to K as inferred by inferContractionDims
- add linalg.matvec and mark it as TODO
- use hyphen (-) rather than underscore (_) in filenames
- implement banach suggestion to replace M, K in explanation with expectedPosOfRowDim etc.
Made all changes requested as much as possible.
✅ With the latest revision this PR passed the C/C++ code formatter.
LGTM, thank you for addressing my comments and for seeing this through - super handy!
I've left one nit - feel free to ignore. I believe that you've already addressed comments from other reviewers, so this can be merged. Please don't forget to fix the formatting ;-)
Thanks again!
#mapA = affine_map<(m, n, k1, k2) -> (m, k1, k2)>
#mapB = affine_map<(m, n, k1, k2) -> (k2, k1, n)>
#mapC = affine_map<(m, n, k1, k2) -> (m, n)>
func.func @op_multi_reduction(%A: tensor<10x20x30xf32>, %B: tensor<30x20x40xf32>,
[nit] You could add "negative" to the test name (e.g. @negative_op_multi_reduction).
…lvm#95656) Add a new mlir-opt pass `--linalg-specialize-generic-ops` which lifts generic, where possible, to linalg named ops. Much like `-linalg-generalize-named-ops` lowers named ops to linalg.generic . Also add patterns to recognize contractions which can be specialized from linalg.generic to named op: `linalg.{batch_}?matmul{_transpose_(a|b)}?`
@javedabsar1 Does this pass change a linalg.generic which is just copying a tensor from input to output to linalg.copy?
Add a new pass `--linalg-specialize-generic-ops` which lifts, where possible, linalg.generic to named ops, much like `-linalg-generalize-named-ops` lowers named ops to linalg.generic.
Also add patterns to recognize contractions which can be specialized from linalg.generic to named op: `linalg.{batch_}?matmul{_transpose_(a|b)}?`.
Patterns to recognize elementwise unary/binary fills/copy were added previously and already exist.