-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] introduce transform.num_associations #76723
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Oleksandr "Alex" Zinenko (ftynse) ChangesAdd a new transform operation that creates a new parameter containing the number of payload objects (operations, values or attributes) associated with the argument. This is useful in matching and for debugging purposes. This replaces three ad-hoc operations previously provided by the test extension. Patch is 27.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76723.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 307257f4a582be..da0162faa6e466 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -438,6 +438,27 @@ def CastOp : TransformDialectOp<"cast",
}];
}
+def NumAssociationsOp : TransformDialectOp<"num_associations",
+ [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ MatchOpInterface]> {
+ let summary =
+ "Returns the number of payload objects associated with the argument";
+ let description = [{
+ Given an argument, handle or parameter, returns a new parameter associated
+ with a single 64-bit number that corresponds to the number of payload
+ objects (operations or values for a handle, attributes for a parameter)
+ associated with the argument.
+
+ Always succeeds.
+ }];
+ let arguments = (ins Transform_AnyHandleOrParamType:$handle);
+ let results = (outs TransformParamTypeInterface:$num);
+ let assemblyFormat = [{
+ $handle attr-dict `:` functional-type(operands, results)
+ }];
+}
+
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7136e423470a28..ca644252f3514a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -32,6 +32,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -1974,6 +1975,34 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
/*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
+//===----------------------------------------------------------------------===//
+// NumAssociationsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ size_t numAssociations =
+ llvm::TypeSwitch<Type, size_t>(getHandle().getType())
+ .Case([&](TransformHandleTypeInterface opHandle) {
+ return llvm::range_size(state.getPayloadOps(getHandle()));
+ })
+ .Case([&](TransformValueHandleTypeInterface valueHandle) {
+ return llvm::range_size(state.getPayloadValues(getHandle()));
+ })
+ .Case([&](TransformParamTypeInterface param) {
+ return llvm::range_size(state.getParams(getHandle()));
+ })
+ .Default([](Type) {
+ llvm_unreachable("unknown kind of transform dialect type");
+ return 0;
+ });
+ results.setParams(getNum().cast<OpResult>(),
+ rewriter.getI64IntegerAttr(numAssociations));
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index 49a52ba9e06f86..aa15ccf0beeee2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -36,13 +36,15 @@ module attributes {transform.with_named_sequence} {
// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+ %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
// Ensure that one linalg.copy was generated.
%mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
+ %p2 = transform.num_associations %mat : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
transform.yield
}
}
@@ -73,18 +75,21 @@ module attributes {transform.with_named_sequence} {
// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+ %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
// Ensure that one linalg.copy was generated.
%linalg_copy = transform.select "linalg.copy" in %new : (!transform.any_op) -> !transform.any_op
+ %p2 = transform.num_associations %linalg_copy : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
// Ensure that one memref.alloca was generated.
%alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
+ %p3 = transform.num_associations %alloca : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
// Make sure that One-Shot Bufferize can bufferize the rest.
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 15942db9b5db20..db5b5f1c786776 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -134,8 +134,9 @@ module attributes {transform.with_named_sequence} {
#linalg.iterator_type<parallel>,
#linalg.iterator_type<reduction>]}
in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %no_match : !transform.any_op
+ %p = transform.num_associations %no_match : (!transform.any_op) -> !transform.param<i64>
+ // expected-remark @below {{0}}
+ transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 6bca6c1fd6bf12..1f9d81a819e7fb 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -41,8 +41,9 @@ module attributes {transform.with_named_sequence} {
padding_dimensions=[0, 1, 2],
pack_paddings=[1, 1, 0]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
+ %p = transform.num_associations %copy_back : (!transform.op<"bufferization.materialize_in_destination">) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
+ transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index d9a11994eb9d90..a39e6f94cb34f6 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -575,8 +575,9 @@ transform.with_pdl_patterns {
%0 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
%2 = merge_handles deduplicate %0, %1 : !transform.any_op
+ %3 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %3 : !transform.param<i64>
}
}
@@ -676,11 +677,13 @@ module {
^bb0(%arg1: !transform.any_op):
%0 = pdl_match @func in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = replicate num(%0) %arg1 : !transform.any_op, !transform.any_op
+ %p = num_associations %1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
%2 = replicate num(%0) %1 : !transform.any_op, !transform.any_op
+ %p2 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
}
}
}
@@ -708,8 +711,9 @@ transform.with_pdl_patterns {
%f = pdl_match @const in %arg1 : (!transform.any_op) -> !transform.any_op
transform.foreach %f : !transform.any_op {
^bb2(%arg2: !transform.any_op):
+ %p = transform.num_associations %arg2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %arg2 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %arg2, "transform applied" : !transform.any_op
}
}
@@ -780,8 +784,9 @@ transform.with_pdl_patterns {
transform.yield %g : !transform.any_op
}
+ %p = transform.num_associations %results : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{3}}
- transform.test_print_number_of_associated_payload_ir_ops %results : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %results, "transform applied" : !transform.any_op
}
}
@@ -877,8 +882,9 @@ transform.sequence failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
%h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@@ -896,13 +902,15 @@ transform.sequence failures(suppress) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// Silenceable failure and all handles are now empty.
%h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %p2 = transform.num_associations %h_2#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %h_2#0 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
}
// -----
@@ -918,12 +926,15 @@ transform.sequence failures(propagate) {
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// No error, last result handle is empty.
%h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
+ %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
+ %p3 = transform.num_associations %h#2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %h#2 : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
}
// -----
@@ -940,10 +951,12 @@ transform.sequence failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli_2 {overflow_result = 0} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{3}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
+ %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
}
// -----
@@ -1668,8 +1681,9 @@ transform.sequence failures(propagate) {
// expected-remark @below {{2 iterations}}
transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
// One replacement op (test.drop_mapping) is dropped from the mapping.
+ %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
}
// -----
@@ -1684,20 +1698,24 @@ module {
%2 = transform.param.constant 1 -> !transform.param<i64>
%3 = transform.param.constant 2 -> !transform.param<i64>
%4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param<i64>
+ %p = num_associations %4 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_params %4 : !transform.param<i64>
+ test_print_param %p : !transform.param<i64>
%5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param<i64>
+ %p2 = num_associations %5 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_params %5 : !transform.param<i64>
+ test_print_param %p2 : !transform.param<i64>
%6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param<i64>
+ %p3 = num_associations %6 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_params %6 : !transform.param<i64>
+ test_print_param %p3 : !transform.param<i64>
%7 = transform.merge_handles %1, %1, %2, %3 : !transform.param<i64>
+ %p4 = num_associations %7 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_params %7 : !transform.param<i64>
+ test_print_param %p4 : !transform.param<i64>
}
}
@@ -1712,21 +1730,25 @@ transform.sequence failures(propagate) {
%3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value
%4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value
+ %p = num_associations %4 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_values %4 : !transform.any_value
+ test_print_param %p : !transform.param<i64>
%5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value
+ %p2 = num_associations %5 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_values %5 : !transform.any_value
+ test_print_param %p2 : !transform.param<i64>
%6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
%7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value
+ %p3 = num_associations %6 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_values %6 : !transform.any_value
+ test_print_param %p3 : !transform.param<i64>
%8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value
+ %p4 = num_associations %8 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_values %8 : !transform.any_value
+ test_print_param %p4 : !transform.param<i64>
}
// -----
@@ -1820,31 +1842,37 @@ transform.sequence failures(propagate) {
// There are 3 arith.constant ops.
%all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op
+ %p = num_associations %all : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// "deduplicate" has no effect because these are 3 different ops.
%merged_before = transform.merge_handles deduplicate %all : !transform.any_op
+ %p2 = num_associations %merged_before : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
// Apply CSE.
transform.apply_cse to %0 : !transform.any_op
// The handle is still mapped to 3 arith.constant ops.
+ %p3 = num_associations %all : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+ test_print_param %p3 : !transform.param<i64>
// But they are all the same op.
%merged_after = transform.merge_handles deduplicate %all : !transform.any_op
+ %p4 = num_associations %merged_after : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op
+ test_print_param %p4 : !transform.param<i64>
// The other handles were also updated.
test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op
+ %p5 = num_associations %elim_first : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op
+ test_print_param %p5 : !transform.param<i64>
test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op
+ %p6 = num_associations %elim_second : (!transform.any_op) -> !transf...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it worth adding a verifier that the result type is either param<i64>
or any_param
? Right now this lets users write
%p = transform.num_associations %op : (!transform.any_op) -> !transform.param<i32>
which will fail during application.
Good point, we can just do the same check at verification time rather than at interpretation time. |
Add a new transform operation that creates a new parameter containing the number of payload objects (operations, values or attributes) associated with the argument. This is useful in matching and for debugging purposes. This replaces three ad-hoc operations previously provided by the test extension.
38942f6
to
43c73f2
Compare
Add a new transform operation that creates a new parameter containing the number of payload objects (operations, values or attributes) associated with the argument. This is useful in matching and for debugging purposes. This replaces three ad-hoc operations previously provided by the test extension.