-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Linalg] Introduce SpecializeOp #70326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: lorenzo chelini (chelini) ChangesIntroduce an operation to specialize linalg.generics, for example, detecting a linalg.generic that is semantically equivalent to a linalg.copy and replacing the former with the latter. After code generation, it is helpful to lower named operations to vendor-optimized libraries. Full diff: https://github.com/llvm/llvm-project/pull/70326.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..2d86d443a28ebbb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -390,6 +390,42 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
}];
}
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===//
+
+def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Transforms a generic operation into the equivalent named form.
+
+ #### Return modes
+
+ This operation ignores non-Linalg ops and drops them in the return.
+ If all the operations referred to by the `target` handle specialize
+ properly, the transform succeeds. Otherwise the transform silently fails.
+ The return handle points to only the subset of successfully produced
+ equivalent named operations, which can be empty or contain the original
+ ops if they were already in named form.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+ let assemblyFormat =
+ "$target attr-dict `:` "
+ "custom<SemiFunctionType>(type($target), type($transformed))";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::linalg::LinalgOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fbe2923c710aabb..122f73562852101 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -668,6 +668,11 @@ FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
LinalgOp namedOp);
+/// Create a namedOp from the given GenericOp and replace the GenericOp.
+/// Currently we can specialize only trivial linalg copy operations.
+FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp);
+
/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that
/// can be computed for the size of the result of `subView`. Returns the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8508507871d0c6c..87be3bb85b6e788 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1018,6 +1018,30 @@ transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultSilenceableFailure(target);
}
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===/
+
+DiagnosedSilenceableFailure
+transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
+ LinalgOp target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ // Exit early if the operation is not a generic.
+ if (!isa<GenericOp>(target)) {
+ results.push_back(target);
+ return DiagnosedSilenceableFailure::success();
+ }
+ rewriter.setInsertionPoint(target);
+ FailureOr<LinalgOp> named =
+ specializeGenericOp(rewriter, cast<GenericOp>(target));
+ if (succeeded(named)) {
+ results.push_back(named->getOperation());
+ return DiagnosedSilenceableFailure::success();
+ }
+ return emitDefaultSilenceableFailure(target);
+}
+
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4e094609afa6a03..5ec1fd5dd7e91db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
+ Specialize.cpp
Split.cpp
SplitReduction.cpp
SubsetHoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
new file mode 100644
index 000000000000000..6c7be63069dad1d
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -0,0 +1,52 @@
+//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a method to specialize generic operations to named
+// operations. Conceptually it is the opposite of generalize.cpp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-specialization"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+static bool isaCopyOp(GenericOp genericOp) {
+ // Structural.
+ if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+ return false;
+
+ // Operands and maps.
+ if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+ return false;
+ auto mapRange = genericOp.getIndexingMapsArray();
+ if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
+ !mapRange.back().isIdentity()) {
+ return false;
+ }
+
+ // Region.
+ Region ® = genericOp.getRegion();
+ if (!llvm::hasSingleElement(reg))
+ return false;
+ return std::distance(reg.front().begin(), reg.front().end()) == 1;
+}
+
+FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ if (isaCopyOp(genericOp)) {
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
+ genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+ return namedOp;
+ }
+ return failure();
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
new file mode 100644
index 000000000000000..a125d2dc3ca29e6
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map1, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ }
+ return
+}
+
+func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %0 = arith.addf %in, %out : f32
+ linalg.yield %0 : f32
+ }
+ return
+}
+
+func.func @transpose_op_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map, #map2],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{failed to apply}}
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: generalize_trivial_copy
+func.func @generalize_trivial_copy(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+ // CHECK: linalg.copy
+ // CHECK-NOT: linalg.generic
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very useful transform, well done!
if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) | ||
return false; | ||
auto mapRange = genericOp.getIndexingMapsArray(); | ||
if (mapRange.size() != 2 || !mapRange.front().isIdentity() || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I came across AffineMaps which are actually identity maps if we take into consideration that the shape of the operands has 1 in it (this actually happens a lot when lowering from Tosa to Linalg), for example:
#map = affine_map<(d0, d1) -> (0, d1)>
Where operands shapes are <1x2xf32>
For those cases, I think it would be useful to have a util functions:
bool isIndexingMapsRepresentIdentity(linalg::GenericOp genericOp) {
return llvm::all_of(
genericOp->getOpOperands(), [&genericOp](OpOperand &opOperand) {
return isCanonicalizedIdentityMap(
genericOp.getMatchingIndexingMap(&opOperand), opOperand);
});
}
} // namespace gcx::LinalgUtils
static bool isCanonicalizedIdentityMap(AffineMap map, OpOperand &opOperand) {
if (map.getNumDims() != map.getNumResults())
return false;
ArrayRef<AffineExpr> results = map.getResults();
auto shape = cast<ShapedType>(opOperand.get().getType()).getShape();
for (unsigned i = 0, numDims = map.getNumDims(); i < numDims; ++i) {
auto constExpr = results[i].dyn_cast<AffineConstantExpr>();
if (constExpr && constExpr.getValue() == 0 && shape[0] == 1)
continue;
auto expr = results[i].dyn_cast<AffineDimExpr>();
if (!expr || expr.getPosition() != i)
return false;
}
return true;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comment. We could add isCanonicalizedIdentityMap
, but linalg.copy does not support broadcast. What use cases do you have in mind? Should we start a small working group around this topic? Linalg specialization (aka linalg.generic -> linalg.named ops) has been discussed multiple times, but no concrete steps have been taken in this direction. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my concrete exmaple I have faced, I lowered tosa.add
to linalg (so it is not exactly lowered into linalg.copy
):
func.func private @CustomAddLayer_kernel(%arg0: tensor<1x24x32x512xf32>, %arg1: tensor<1x24x32x512xf32>) -> (tensor<1x24x32x512xf32>) {
%0 = tosa.add %arg0, %arg1 : (tensor<1x24x32x512xf32>, tensor<1x24x32x512xf32>) -> tensor<1x24x32x512xf32>
return %0 : tensor<1x24x32x512xf32>
}
// -----// IR Dump After TosaToLinalg (tosa-to-linalg) //----- //
func.func private @CustomAddLayer_kernel(%arg0: tensor<1x24x32x512xf32>, %arg1: tensor<1x24x32x512xf32>) -> (tensor<1x24x32x512xf32>) {
%0 = tensor.empty() : tensor<1x24x32x512xf32>
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1x24x32x512xf32>, tensor<1x24x32x512xf32>) outs(%0 : tensor<1x24x32x512xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<1x24x32x512xf32>
return %1 : tensor<1x24x32x512xf32>
}
In this example, we actually use some CanonicalizedIdentityMap
in the linalg.generic
indexing maps.
This is actually happens a lot when lowering from Tosa to Linalg, so I believe it would be useful to have such function in the LinalgOp
interface and use it also here. If you agree, it can also be in different patch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @AviadCo. Does FoldUnitExtentDims
work for you? We could run the transform right after FoldUnitExtentDims
.
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Outdated
Show resolved
Hide resolved
|
||
This operation ignores non-Linalg ops and drops them in the return. | ||
If all the operations referred to by the `target` handle specialize | ||
properly, the transform succeeds. Otherwise the transform silently fails. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: it's probably my fault in naming, but please don't use "silently fails". The failure is not silent, it's silenceable. It will be reported by default unless the surrounding context requests that otherwise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, updated!
succeeds; otherwise, the operation produces a silenceable failure. The return | ||
handle points to only the subset of successfully produced equivalent named | ||
operations, which can be empty or contain the original ops if they were already | ||
in named form. Only linalg.copy specialization is available, but more will |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we rephrase the last sentence in a new paragraph with an itemized list that we can expand in the future?
The supported specializations to named Linalg operations are:
- linalg.copy of any rank
Introduce an operation to specialize linalg.generics, for example, detecting a linalg.generic that is semantically equivalent to a linalg.copy and replacing the former with the latter. After code generation, it is helpful to lower named operations to vendor-optimized libraries.
Introduce an operation to specialize linalg.generics, for example, detecting a linalg.generic that is semantically equivalent to a linalg.copy and replacing the former with the latter. After code generation, it is helpful to lower named operations to vendor-optimized libraries.